Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ReceiveFromAsync and SendToAsync with SocketAddress overload #90086

Merged
merged 11 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ internal static partial class Interop
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Bind")]
private static partial Error Bind(SafeHandle socket, ProtocolType socketProtocolType, ReadOnlySpan<byte> socketAddress, int socketAddressLen);
private static partial Error Bind(SafeHandle socket, int socketProtocolType, ReadOnlySpan<byte> socketAddress, int socketAddressLen);
wfurt marked this conversation as resolved.
Show resolved Hide resolved

internal static Error Bind(
SafeHandle socket, ProtocolType socketProtocolType, ReadOnlySpan<byte> socketAddress)
=> Bind(socket, socketProtocolType, socketAddress, socketAddress.Length);
=> Bind(socket, (int)socketProtocolType, socketAddress, socketAddress.Length);
}
}
25 changes: 25 additions & 0 deletions src/libraries/Common/src/System/Net/IPEndPointExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,30 @@ public static void Serialize(this IPEndPoint endPoint, Span<byte> destination)
SetIPAddress(destination, endPoint.Address);
SocketAddressPal.SetPort(destination, (ushort)endPoint.Port);
}

public static bool Equals(this IPEndPoint endPoint, ReadOnlySpan<byte> socketAddressBuffer)
{
if (socketAddressBuffer.Length >= SocketAddress.GetMaximumAddressSize(endPoint.AddressFamily) &&
endPoint.AddressFamily == SocketAddressPal.GetAddressFamily(socketAddressBuffer) &&
endPoint.Port == (int)SocketAddressPal.GetPort(socketAddressBuffer))
{
if (endPoint.AddressFamily == AddressFamily.InterNetwork)
{
#pragma warning disable CS0618
return endPoint.Address.Address == (long)SocketAddressPal.GetIPv4Address(socketAddressBuffer);
#pragma warning restore CS0618
}
else
{
Span<byte> addressBuffer1 = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
Span<byte> addressBuffer2 = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
SocketAddressPal.GetIPv6Address(socketAddressBuffer, addressBuffer1, out uint scopeid);
endPoint.Address.TryWriteBytes(addressBuffer2, out _);
return endPoint.Address.ScopeId == (long)scopeid && addressBuffer1.SequenceEqual(addressBuffer2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we compare the scope id before doing the TryWriteBytes above?

}
}

return false;
}
}
}
18 changes: 1 addition & 17 deletions src/libraries/Common/src/System/Net/SocketAddress.cs
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers.Binary;
using System.Diagnostics;
using System.Globalization;
using System.Net.Sockets;
using System.Text;

#if SYSTEM_NET_PRIMITIVES_DLL
namespace System.Net
#else
namespace System.Net.Internals
#endif
{
// This class is used when subclassing EndPoint, and provides indication
// on how to format the memory buffers that the platform uses for network addresses.
#if SYSTEM_NET_PRIMITIVES_DLL
public
#else
internal sealed
#endif
class SocketAddress : System.IEquatable<SocketAddress>
public class SocketAddress : System.IEquatable<SocketAddress>
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
#pragma warning disable CA1802 // these could be const on Windows but need to be static readonly for Unix
internal static readonly int IPv6AddressSize = SocketAddressPal.IPv6AddressSize;
Expand Down Expand Up @@ -176,11 +164,7 @@ internal IPAddress GetIPAddress()
}
else
{
#if SYSTEM_NET_PRIMITIVES_DLL
throw new SocketException(SocketError.AddressFamilyNotSupported);
#else
throw new SocketException((int)SocketError.AddressFamilyNotSupported);
#endif
}
}

Expand Down
40 changes: 40 additions & 0 deletions src/libraries/Common/src/System/Net/SocketAddressExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Net;

namespace System.Net.Sockets
{
internal static partial class SocketAddressExtensions
{
public static IPAddress GetIPAddress(this SocketAddress socketAddress) => IPEndPointExtensions.GetIPAddress(socketAddress.Buffer.Span);
public static int GetPort(this SocketAddress socketAddress)
{
Debug.Assert(socketAddress.Family == AddressFamily.InterNetwork || socketAddress.Family == AddressFamily.InterNetworkV6);
return (int)SocketAddressPal.GetPort(socketAddress.Buffer.Span);
}

public static IPEndPoint GetIPEndPoint(this SocketAddress socketAddress)
{
return new IPEndPoint(socketAddress.GetIPAddress(), socketAddress.GetPort());
}

public static bool Equals(this SocketAddress socketAddress, EndPoint? endPoint)
{
if (endPoint is IPEndPoint ipe)
{
if (socketAddress.Family == endPoint.AddressFamily)
{
return ipe.Equals(socketAddress.Buffer.Span);
}
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved

// We could serialize other EndPoints and compare socket addresses.
// But that would do two allocations and is probably as expensive as
// allocating new EndPoint.
// This may change if https://github.com/dotnet/runtime/issues/78993 is done
return false;
}
}
}
6 changes: 0 additions & 6 deletions src/libraries/Common/src/System/Net/Sockets/ProtocolType.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#if SYSTEM_NET_SOCKETS_DLL
namespace System.Net.Sockets
{
public
#else
namespace System.Net.Internals
{
internal
#endif
// Specifies the protocols that the Socket class supports.
enum ProtocolType
{
Expand Down
6 changes: 0 additions & 6 deletions src/libraries/Common/src/System/Net/Sockets/SocketType.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#if SYSTEM_NET_SOCKETS_DLL
namespace System.Net.Sockets
{
public
#else
namespace System.Net.Internals
{
internal
#endif
// Specifies the type of socket an instance of the System.Net.Sockets.Socket class represents.
enum SocketType
{
Expand Down
4 changes: 3 additions & 1 deletion src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,12 @@ public void Listen(int backlog) { }
public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; }
public int ReceiveFrom(System.Span<byte> buffer, ref System.Net.EndPoint remoteEP) { throw null; }
public int ReceiveFrom(System.Span<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; }
public int ReceiveFrom(System.Span<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedSocketAddress) { throw null; }
public int ReceiveFrom(System.Span<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedAddress) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.ArraySegment<byte> buffer, System.Net.EndPoint remoteEndPoint) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask<int> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress receivedAddress, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
public int ReceiveMessageFrom(System.Span<byte> buffer, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
Expand Down Expand Up @@ -451,6 +452,7 @@ public void SendFile(string? fileName, System.ReadOnlySpan<byte> preBuffer, Syst
public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendToAsync(System.ReadOnlyMemory<byte> buffer, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendToAsync(System.ReadOnlyMemory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendToAsync(System.ReadOnlyMemory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.SocketAddress socketAddress, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
public void SetIPProtectionLevel(System.Net.Sockets.IPProtectionLevel level) { }
public void SetRawSocketOption(int optionLevel, int optionName, System.ReadOnlySpan<byte> optionValue) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,12 @@
Link="Common\System\Net\ExceptionCheck.cs" />
<Compile Include="$(CommonPath)System\Net\RangeValidationHelpers.cs"
Link="Common\System\Net\RangeValidationHelpers.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddress.cs"
Link="Common\System\Net\SocketAddress.cs" />
<Compile Include="$(CommonPath)System\Net\TcpValidationHelpers.cs"
Link="Common\System\Net\TcpValidationHelpers.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddressExtensions.cs"
Link="Common\System\Net\SocketAddressExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.cs"
Link="Common\System\Net\SocketProtocolSupportPal.cs" />
<!-- System.Net.Internals -->
<Compile Include="$(CommonPath)System\Net\Internals\IPEndPointExtensions.cs"
Link="Common\System\Net\Internals\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\Internals\IPAddressExtensions.cs"
Link="Common\System\Net\Internals\IPAddressExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\SocketExceptionFactory.cs"
Link="Common\System\Net\Sockets\SocketExceptionFactory.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\ProtocolFamily.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,42 @@ public ValueTask<SocketReceiveFromResult> ReceiveFromAsync(Memory<byte> buffer,
return saea.ReceiveFromAsync(this, cancellationToken);
}

/// <summary>
/// Receives data and returns the endpoint of the sending host.
/// </summary>
/// <param name="buffer">The buffer for the received data.</param>
/// <param name="socketFlags">A bitwise combination of SocketFlags values that will be used when receiving the data.</param>
/// <param name="receivedAddress">An <see cref="SocketAddress"/>, that will be updated with value of the remote peer.</param>
/// <param name="cancellationToken">A cancellation token that can be used to signal the asynchronous operation should be canceled.</param>
/// <returns>An asynchronous task that completes with a <see cref="SocketReceiveFromResult"/> containing the number of bytes received and the endpoint of the sending host.</returns>
public ValueTask<int> ReceiveFromAsync(Memory<byte> buffer, SocketFlags socketFlags, SocketAddress receivedAddress, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
ArgumentNullException.ThrowIfNull(receivedAddress, nameof(receivedAddress));

if (receivedAddress.Size < SocketAddress.GetMaximumAddressSize(AddressFamily))
{
throw new ArgumentOutOfRangeException(nameof(receivedAddress), SR.net_sockets_address_small);
}

if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled<int>(cancellationToken);
}

AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: true);

Debug.Assert(saea.BufferList == null);
saea.SetBuffer(buffer);
saea.SocketFlags = socketFlags;
saea.RemoteEndPoint = null;
saea._socketAddress = receivedAddress;
saea.WrapExceptionsForNetworkStream = false;
return saea.ReceiveFromSaAsync(this, cancellationToken);
}

/// <summary>
/// Receives data and returns additional information about the sender of the message.
/// </summary>
Expand Down Expand Up @@ -636,11 +672,41 @@ public ValueTask<int> SendToAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
Debug.Assert(saea.BufferList == null);
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea._socketAddress = null;
saea.RemoteEndPoint = remoteEP;
saea.WrapExceptionsForNetworkStream = false;
return saea.SendToAsync(this, cancellationToken);
}

/// <summary>
/// Sends data to the specified remote host.
/// </summary>
/// <param name="buffer">The buffer for the data to send.</param>
/// <param name="socketFlags">A bitwise combination of SocketFlags values that will be used when sending the data.</param>
/// <param name="socketAddress">The remote host to which to send the data.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the number of bytes sent.</returns>
public ValueTask<int> SendToAsync(ReadOnlyMemory<byte> buffer, SocketFlags socketFlags, SocketAddress socketAddress, CancellationToken cancellationToken = default)
{
wfurt marked this conversation as resolved.
Show resolved Hide resolved
ArgumentNullException.ThrowIfNull(socketAddress);

if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled<int>(cancellationToken);
}

AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

Debug.Assert(saea.BufferList == null);
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea._socketAddress = socketAddress;
saea.WrapExceptionsForNetworkStream = false;
return saea.SendToAsync(this, cancellationToken);
}

/// <summary>
/// Sends the file <paramref name="fileName"/> to a connected <see cref="Socket"/> object.
/// </summary>
Expand Down Expand Up @@ -1019,6 +1085,24 @@ public ValueTask<SocketReceiveFromResult> ReceiveFromAsync(Socket socket, Cancel
ValueTask.FromException<SocketReceiveFromResult>(CreateException(error));
}

internal ValueTask<int> ReceiveFromSaAsync(Socket socket, CancellationToken cancellationToken)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
if (socket.ReceiveFromAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<int>(this, _mrvtsc.Version);
}

int bytesTransferred = BytesTransferred;
SocketError error = SocketError;

ReleaseForSyncCompletion();

return error == SocketError.Success ?
new ValueTask<int>(bytesTransferred) :
ValueTask.FromException<int>(CreateException(error));
}

public ValueTask<SocketReceiveMessageFromResult> ReceiveMessageFromAsync(Socket socket, CancellationToken cancellationToken)
{
if (socket.ReceiveMessageFromAsync(this, cancellationToken))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ internal void ReplaceHandleIfNecessaryAfterFailedConnect() { /* nop on Windows *
private sealed class CachedSerializedEndPoint
{
public readonly IPEndPoint IPEndPoint;
public readonly Internals.SocketAddress SocketAddress;
public readonly SocketAddress SocketAddress;

public CachedSerializedEndPoint(IPAddress address)
{
IPEndPoint = new IPEndPoint(address, 0);
SocketAddress = IPEndPointExtensions.Serialize(IPEndPoint);
SocketAddress = IPEndPoint.Serialize();
}
}

Expand Down Expand Up @@ -70,7 +70,7 @@ public Socket(SocketInformation socketInformation)
IPAddress tempAddress = _addressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any;
IPEndPoint ep = new IPEndPoint(tempAddress, 0);

Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep);
SocketAddress socketAddress = ep.Serialize();
int size = socketAddress.Buffer.Length;
unsafe
{
Expand Down
Loading