Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use approved SocketAddress API instead of direct internal access #89841

Merged
merged 4 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ internal static partial class Interop
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Bind")]
internal static unsafe partial Error Bind(SafeHandle socket, ProtocolType socketProtocolType, byte* socketAddress, int socketAddressLen);
private static partial Error Bind(SafeHandle socket, ProtocolType socketProtocolType, ReadOnlySpan<byte> socketAddress, int socketAddressLen);

internal static Error Bind(
SafeHandle socket, ProtocolType socketProtocolType, ReadOnlySpan<byte> socketAddress)
=> Bind(socket, socketProtocolType, socketAddress, socketAddress.Length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
using System;
using System.Net.Sockets;
using System.Runtime.InteropServices;
#if !SYSTEM_NET_SOCKETS_DLL
using SocketType = System.Net.Internals.SocketType;
#endif

internal static partial class Interop
{
Expand All @@ -15,7 +12,7 @@ internal static partial class Winsock
[LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true, StringMarshalling = StringMarshalling.Utf16)]
internal static partial IntPtr WSASocketW(
AddressFamily addressFamily,
SocketType socketType,
int socketType,
int protocolType,
IntPtr protocolInfo,
int group,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Net.Sockets;

Expand All @@ -9,9 +10,13 @@ internal static partial class Interop
internal static partial class Winsock
{
[LibraryImport(Interop.Libraries.Ws2_32, SetLastError = true)]
internal static partial SocketError bind(
private static partial SocketError bind(
SafeSocketHandle socketHandle,
byte[] socketAddress,
ReadOnlySpan<byte> socketAddress,
int socketAddressSize);

internal static SocketError bind(
SafeSocketHandle socketHandle,
ReadOnlySpan<byte> socketAddress) => bind(socketHandle, socketAddress, socketAddress.Length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace System.Net.Sockets
{
internal static class IPEndPointExtensions
internal static partial class IPEndPointExtensions
{
public static IPAddress GetIPAddress(ReadOnlySpan<byte> socketAddressBuffer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace System.Net.Sockets
{
internal static class IPEndPointExtensions
internal static partial class IPEndPointExtensions
{
public static Internals.SocketAddress Serialize(EndPoint endpoint)
{
Expand Down
50 changes: 25 additions & 25 deletions src/libraries/Common/src/System/Net/SocketAddress.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class SocketAddress : System.IEquatable<SocketAddress>
internal static readonly int MaxAddressSize = SocketAddressPal.MaxAddressSize;
#pragma warning restore CA1802

internal int InternalSize;
internal byte[] InternalBuffer;
private int _size;
private byte[] _buffer;

private const int MinSize = 2;
private const int DataOffset = 2;
Expand All @@ -39,21 +39,21 @@ public AddressFamily Family
{
get
{
return SocketAddressPal.GetAddressFamily(InternalBuffer);
return SocketAddressPal.GetAddressFamily(_buffer);
}
}

public int Size
{
get
{
return InternalSize;
return _size;
}
set
{
ArgumentOutOfRangeException.ThrowIfGreaterThan(value, InternalBuffer.Length);
ArgumentOutOfRangeException.ThrowIfGreaterThan(value, _buffer.Length);
ArgumentOutOfRangeException.ThrowIfLessThan(value, MinSize);
InternalSize = value;
_size = value;
}
}

Expand All @@ -69,15 +69,15 @@ public byte this[int offset]
{
throw new IndexOutOfRangeException();
}
return InternalBuffer[offset];
return _buffer[offset];
}
set
{
if ((uint)offset >= (uint)Size)
{
throw new IndexOutOfRangeException();
}
InternalBuffer[offset] = value;
_buffer[offset] = value;
}
}

Expand All @@ -97,11 +97,11 @@ public SocketAddress(AddressFamily family, int size)
{
ArgumentOutOfRangeException.ThrowIfLessThan(size, MinSize);

InternalSize = size;
InternalBuffer = new byte[size];
InternalBuffer[0] = (byte)InternalSize;
_size = size;
_buffer = new byte[size];
_buffer[0] = (byte)_size;

SocketAddressPal.SetAddressFamily(InternalBuffer, family);
SocketAddressPal.SetAddressFamily(_buffer, family);
}

internal SocketAddress(IPAddress ipAddress)
Expand All @@ -110,15 +110,15 @@ internal SocketAddress(IPAddress ipAddress)
{

// No Port.
SocketAddressPal.SetPort(InternalBuffer, 0);
SocketAddressPal.SetPort(_buffer, 0);

if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6)
{
Span<byte> addressBytes = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
ipAddress.TryWriteBytes(addressBytes, out int bytesWritten);
Debug.Assert(bytesWritten == IPAddressParserStatics.IPv6AddressBytes);

SocketAddressPal.SetIPv6Address(InternalBuffer, addressBytes, (uint)ipAddress.ScopeId);
SocketAddressPal.SetIPv6Address(_buffer, addressBytes, (uint)ipAddress.ScopeId);
}
else
{
Expand All @@ -127,21 +127,21 @@ internal SocketAddress(IPAddress ipAddress)
#pragma warning restore CS0618

Debug.Assert(ipAddress.AddressFamily == AddressFamily.InterNetwork);
SocketAddressPal.SetIPv4Address(InternalBuffer, address);
SocketAddressPal.SetIPv4Address(_buffer, address);
}
}

internal SocketAddress(IPAddress ipaddress, int port)
: this(ipaddress)
{
SocketAddressPal.SetPort(InternalBuffer, unchecked((ushort)port));
SocketAddressPal.SetPort(_buffer, unchecked((ushort)port));
}

internal SocketAddress(AddressFamily addressFamily, ReadOnlySpan<byte> buffer)
{
InternalBuffer = buffer.ToArray();
InternalSize = InternalBuffer.Length;
SocketAddressPal.SetAddressFamily(InternalBuffer, addressFamily);
_buffer = buffer.ToArray();
_size = _buffer.Length;
SocketAddressPal.SetAddressFamily(_buffer, addressFamily);
}

/// <summary>This represents underlying memory that can be passed to native OS calls.</summary>
Expand All @@ -152,7 +152,7 @@ public Memory<byte> Buffer
{
get
{
return new Memory<byte>(InternalBuffer, 0, InternalSize);
return new Memory<byte>(_buffer, 0, _size);
}
}

Expand All @@ -164,14 +164,14 @@ internal IPAddress GetIPAddress()

Span<byte> address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes];
uint scope;
SocketAddressPal.GetIPv6Address(InternalBuffer, address, out scope);
SocketAddressPal.GetIPv6Address(_buffer, address, out scope);

return new IPAddress(address, (long)scope);
}
else if (Family == AddressFamily.InterNetwork)
{
Debug.Assert(Size >= IPv4AddressSize);
long address = (long)SocketAddressPal.GetIPv4Address(InternalBuffer) & 0x0FFFFFFFF;
long address = (long)SocketAddressPal.GetIPv4Address(_buffer) & 0x0FFFFFFFF;
return new IPAddress(address);
}
else
Expand All @@ -184,7 +184,7 @@ internal IPAddress GetIPAddress()
}
}

internal int GetPort() => (int)SocketAddressPal.GetPort(InternalBuffer);
internal int GetPort() => (int)SocketAddressPal.GetPort(_buffer);

internal IPEndPoint GetIPEndPoint()
{
Expand All @@ -199,7 +199,7 @@ public override bool Equals(object? comparand) =>
public override int GetHashCode()
{
HashCode hash = default;
hash.AddBytes(new ReadOnlySpan<byte>(InternalBuffer, 0, InternalSize));
hash.AddBytes(new ReadOnlySpan<byte>(_buffer, 0, _size));
return hash.ToHashCode();
}

Expand Down Expand Up @@ -234,7 +234,7 @@ public override string ToString()
result[length++] = ':';
result[length++] = '{';

byte[] buffer = InternalBuffer;
byte[] buffer = _buffer;
for (int i = DataOffset; i < Size; i++)
{
if (i > DataOffset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace System.Net
{
internal static partial class SocketProtocolSupportPal
{
private const int DgramSocketType = 2;
private static unsafe bool IsSupported(AddressFamily af)
{
// Check for AF_UNIX on iOS/tvOS. The OS claims to support this, but returns EPERM on bind.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

using System.Net.Sockets;
using System.Runtime.InteropServices;
#if !SYSTEM_NET_SOCKETS_DLL
using SocketType = System.Net.Internals.SocketType;
#endif

namespace System.Net
{
Expand All @@ -19,7 +16,7 @@ private static bool IsSupported(AddressFamily af)
IntPtr socket = INVALID_SOCKET;
try
{
socket = Interop.Winsock.WSASocketW(af, SocketType.Stream, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);
socket = Interop.Winsock.WSASocketW(af, DgramSocketType, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);
return
socket != INVALID_SOCKET ||
(SocketError)Marshal.GetLastPInvokeError() != SocketError.AddressFamilyNotSupported;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ internal static partial class SocketProtocolSupportPal
public static bool OSSupportsIPv4 { get; } = IsSupported(AddressFamily.InterNetwork);
public static bool OSSupportsUnixDomainSockets { get; } = IsSupported(AddressFamily.Unix);

private const int DgramSocketType = 2;

private static bool IsIPv6Disabled()
{
// First check for the AppContext switch, giving it priority over the environment variable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<!-- System.Net common -->
<Compile Include="$(CommonPath)System\Net\Sockets\ProtocolType.cs"
Link="Common\System\Net\Sockets\ProtocolType.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\SocketType.cs"
Link="Common\System\Net\Sockets\SocketType.cs" />
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do it later if you'd prefer, but these files should also move out of common into the appropriate library folder

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I was thinking about SocketAddress.cs as well once the changes are over. Probably separate PR with just moves/renames.

<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointStatics.cs"
Link="Common\System\Net\IPEndPointStatics.cs" />
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.cs"
Link="Common\System\Net\SocketProtocolSupportPal.cs" />
</ItemGroup>
Expand All @@ -37,9 +31,7 @@
<!-- Debug only -->
<Compile Include="$(CommonPath)System\Net\DebugSafeHandle.cs"
Link="Common\System\Net\DebugSafeHandle.cs" />
<!-- System.Net.Internals -->
<Compile Include="$(CommonPath)System\Net\Internals\IPAddressExtensions.cs"
Link="Common\System\Net\Internals\IPAddressExtensions.cs" />
<!-- System.Net common -->
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.Windows.cs"
Link="Common\System\Net\SocketProtocolSupportPal.Windows" />
<Compile Include="$(CommonPath)System\Net\SocketAddressPal.Windows.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Net.Internals;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Internals;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Text;
Expand Down
14 changes: 4 additions & 10 deletions src/libraries/System.Net.Ping/src/System.Net.Ping.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,8 @@
<Compile Include="$(CommonPath)System\Obsoletions.cs"
Link="Common\System\Obsoletions.cs" />
<!-- System.Net Common -->
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.cs"
Link="Common\System\Net\SocketProtocolSupportPal.cs" />
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\SocketType.cs"
Link="Common\System\Net\Sockets\SocketType.cs" />
</ItemGroup>
<ItemGroup Condition="('$(TargetPlatformIdentifier)' != '' and '$(TargetPlatformIdentifier)' != 'windows')">
<Compile Include="System\Net\NetworkInformation\IcmpV4MessageConstants.cs" />
Expand All @@ -40,8 +32,6 @@
<!-- System.Net Common -->
<Compile Include="$(CommonPath)System\Net\RawSocketPermissions.cs"
Link="Common\System\Net\RawSocketPermissions.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddressPal.Unix.cs"
Link="Common\System\Net\SocketAddressPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.Unix.cs"
Link="Common\System\Net\SocketProtocolSupportPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\NetworkInformation\UnixCommandLinePing.cs"
Expand Down Expand Up @@ -72,6 +62,10 @@
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'windows'">
<Compile Include="System\Net\NetworkInformation\Ping.Windows.cs" />
<!-- System.Net Common -->
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddressPal.Windows.cs"
Link="Common\System\Net\SocketAddressPal.Windows.cs" />
<Compile Include="$(CommonPath)System\Net\SocketProtocolSupportPal.Windows.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
<Compile Include="$(CommonPath)System\Net\DebugSafeHandleMinusOneIsInvalid.cs"
Link="Common\System\Net\DebugSafeHandleMinusOneIsInvalid.cs" />
<!-- System.Net common -->
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointStatics.cs"
Link="Common\System\Net\IPEndPointStatics.cs" />
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,18 @@ public Socket(SocketInformation socketInformation)
IPEndPoint ep = new IPEndPoint(tempAddress, 0);

Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep);
int size = socketAddress.Buffer.Length;
unsafe
{
fixed (byte* bufferPtr = socketAddress.InternalBuffer)
fixed (int* sizePtr = &socketAddress.InternalSize)
fixed (byte* bufferPtr = socketAddress.Buffer.Span)
{
errorCode = SocketPal.GetSockName(_handle, bufferPtr, sizePtr);
errorCode = SocketPal.GetSockName(_handle, bufferPtr, &size);
}
}

if (errorCode == SocketError.Success)
{
socketAddress.Size = size;
_rightEndPoint = ep.Create(socketAddress);
}
else if (errorCode == SocketError.InvalidArgument)
Expand Down
Loading