Skip to content

Commit

Permalink
Add Socket.OSSupportsUnixDomainSocket (#32160)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Feb 15, 2020
1 parent 9fedc93 commit 03fc3ab
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,29 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Net.Internals;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;

namespace System.Net
{
internal class SocketProtocolSupportPal
internal static class SocketProtocolSupportPal
{
private static bool s_ipv4 = true;
private static bool s_ipv6 = true;
public static bool OSSupportsIPv6 { get; } = IsSupported(AddressFamily.InterNetworkV6);
public static bool OSSupportsIPv4 { get; } = IsSupported(AddressFamily.InterNetwork);
public static bool OSSupportsUnixDomainSockets { get; } = IsSupported(AddressFamily.Unix);

private static bool s_initialized;
private static readonly object s_initializedLock = new object();

public static bool OSSupportsIPv6
{
get
{
EnsureInitialized();
return s_ipv6;
}
}

public static bool OSSupportsIPv4
{
get
{
EnsureInitialized();
return s_ipv4;
}
}

private static void EnsureInitialized()
{
if (!Volatile.Read(ref s_initialized))
{
lock (s_initializedLock)
{
if (!s_initialized)
{
s_ipv4 = IsProtocolSupported(AddressFamily.InterNetwork);
s_ipv6 = IsProtocolSupported(AddressFamily.InterNetworkV6);

Volatile.Write(ref s_initialized, true);
}
}
}
}

private static unsafe bool IsProtocolSupported(AddressFamily af)
private static unsafe bool IsSupported(AddressFamily af)
{
IntPtr socket = (IntPtr)(-1);
IntPtr invalid = (IntPtr)(-1);
IntPtr socket = invalid;
try
{
Interop.Error err = Interop.Sys.Socket(af, SocketType.Dgram, (ProtocolType)0, &socket);
return err != Interop.Error.EAFNOSUPPORT;
return Interop.Sys.Socket(af, SocketType.Dgram, 0, &socket) != Interop.Error.EAFNOSUPPORT;
}
finally
{
if (socket != (IntPtr)(-1))
if (socket == invalid)
{
Interop.Sys.Close(socket);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,94 +2,38 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
#if !SYSTEM_NET_SOCKETS_DLL
using SocketType = System.Net.Internals.SocketType;
#endif

namespace System.Net
{
internal class SocketProtocolSupportPal
internal static class SocketProtocolSupportPal
{
private static bool s_ipv4 = true;
private static bool s_ipv6 = true;
public static bool OSSupportsIPv6 { get; } = IsSupported(AddressFamily.InterNetworkV6);
public static bool OSSupportsIPv4 { get; } = IsSupported(AddressFamily.InterNetwork);
public static bool OSSupportsUnixDomainSockets { get; } = IsSupported(AddressFamily.Unix);

private static bool s_initialized;
private static readonly object s_initializedLock = new object();

public static bool OSSupportsIPv6
private static bool IsSupported(AddressFamily af)
{
get
{
EnsureInitialized();
return s_ipv6;
}
}

public static bool OSSupportsIPv4
{
get
{
EnsureInitialized();
return s_ipv4;
}
}

private static void EnsureInitialized()
{
if (!Volatile.Read(ref s_initialized))
{
lock (s_initializedLock)
{
if (!s_initialized)
{
s_ipv4 = IsProtocolSupported(AddressFamily.InterNetwork);
s_ipv6 = IsProtocolSupported(AddressFamily.InterNetworkV6);

Volatile.Write(ref s_initialized, true);
}
}
}
}

private static bool IsProtocolSupported(AddressFamily af)
{
SocketError errorCode;
IntPtr s = IntPtr.Zero;
bool ret = true;

IntPtr INVALID_SOCKET = (IntPtr)(-1);
IntPtr socket = INVALID_SOCKET;
try
{
s = Interop.Winsock.WSASocketW(af, SocketType.Dgram, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);

if (s == IntPtr.Zero)
{
errorCode = (SocketError)Marshal.GetLastWin32Error();
if (errorCode == SocketError.AddressFamilyNotSupported)
{
ret = false;
}
}
socket = Interop.Winsock.WSASocketW(af, SocketType.Stream, 0, IntPtr.Zero, 0, (int)Interop.Winsock.SocketConstructorFlags.WSA_FLAG_NO_HANDLE_INHERIT);
return
socket != INVALID_SOCKET ||
(SocketError)Marshal.GetLastWin32Error() != SocketError.AddressFamilyNotSupported;
}
finally
{
if (s != IntPtr.Zero)
if (socket != INVALID_SOCKET)
{
SocketError closeResult = Interop.Winsock.closesocket(s);
#if DEBUG
if (closeResult != SocketError.Success)
{
errorCode = (SocketError)Marshal.GetLastWin32Error();
Debug.Fail("Failed to detect " + af.ToString() + " protocol: " + errorCode.ToString());
}
#endif
Interop.Winsock.closesocket(socket);
}
}

return ret;
}
}
}
1 change: 1 addition & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto
public bool NoDelay { get { throw null; } set { } }
public static bool OSSupportsIPv4 { get { throw null; } }
public static bool OSSupportsIPv6 { get { throw null; } }
public static bool OSSupportsUnixDomainSockets { get { throw null; } }
public System.Net.Sockets.ProtocolType ProtocolType { get { throw null; } }
public int ReceiveBufferSize { get { throw null; } set { } }
public int ReceiveTimeout { get { throw null; } set { } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ public static bool OSSupportsIPv6
}
}

public static bool OSSupportsUnixDomainSockets
{
get
{
InitializeSockets();
return SocketProtocolSupportPal.OSSupportsUnixDomainSockets;
}
}

// Gets the amount of data pending in the network's input buffer that can be
// read from the socket.
public int Available
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,6 @@ public sealed partial class UnixDomainSocketEndPoint : EndPoint
{
private const AddressFamily EndPointAddressFamily = AddressFamily.Unix;

private static readonly Encoding s_pathEncoding = Encoding.UTF8;
private static readonly Lazy<bool> s_udsSupported = new Lazy<bool>(() =>
{
try
{
new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified).Dispose();
return true;
}
catch
{
return false;
}
});

private readonly string _path;
private readonly byte[] _encodedPath;

Expand All @@ -39,7 +25,7 @@ public UnixDomainSocketEndPoint(string path)
// Pathname socket addresses should be null-terminated.
// Linux abstract socket addresses start with a zero byte, they must not be null-terminated.
bool isAbstract = IsAbstract(path);
int bufferLength = s_pathEncoding.GetByteCount(path);
int bufferLength = Encoding.UTF8.GetByteCount(path);
if (!isAbstract)
{
// for null terminator
Expand All @@ -55,10 +41,10 @@ public UnixDomainSocketEndPoint(string path)

_path = path;
_encodedPath = new byte[bufferLength];
int bytesEncoded = s_pathEncoding.GetBytes(path, 0, path.Length, _encodedPath, 0);
int bytesEncoded = Encoding.UTF8.GetBytes(path, 0, path.Length, _encodedPath, 0);
Debug.Assert(bufferLength - (isAbstract ? 0 : 1) == bytesEncoded);

if (!s_udsSupported.Value)
if (!Socket.OSSupportsUnixDomainSockets)
{
throw new PlatformNotSupportedException();
}
Expand Down Expand Up @@ -95,7 +81,7 @@ internal UnixDomainSocketEndPoint(SocketAddress socketAddress)
length--;
}
}
_path = s_pathEncoding.GetString(_encodedPath, 0, length);
_path = Encoding.UTF8.GetString(_encodedPath, 0, length);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ await Assert.ThrowsAsync<InvalidOperationException>(() =>
[Fact]
public void SocketCtr_SocketInformation_NonIpSocket_ThrowsNotSupportedException()
{
// UDS unsupported:
if (!PlatformDetection.IsWindows10Version1803OrGreater || !Environment.Is64BitProcess) return;
if (!Socket.OSSupportsUnixDomainSockets) return;

using Socket original = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
SocketInformation info = original.DuplicateAndClose(CurrentProcessId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ namespace System.Net.Sockets.Tests
{
public class UnixDomainSocketTest
{
[Fact]
public void OSSupportsUnixDomainSockets_ReturnsCorrectValue()
{
Assert.Equal(PlatformSupportsUnixDomainSockets, Socket.OSSupportsUnixDomainSockets);
}

[PlatformSpecific(~TestPlatforms.Windows)] // Windows doesn't currently support ConnectEx with domain sockets
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_Success()
Expand Down Expand Up @@ -446,17 +452,13 @@ private static bool PlatformSupportsUnixDomainSockets
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// UDS support added in April 2018 Update
if (!PlatformDetection.IsWindows10Version1803OrGreater || PlatformDetection.IsWindowsNanoServer)
try
{
return false;
using var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Tcp);
}

// TODO: Windows 10 April 2018 Update doesn't support UDS in 32-bit processes on 64-bit OSes,
// allowing the socket call to succeed but then failing in bind. Remove this check once it's supported.
if (!Environment.Is64BitProcess && Environment.Is64BitOperatingSystem)
catch (SocketException se)
{
return false;
return se.SocketErrorCode != SocketError.AddressFamilyNotSupported;
}
}

Expand Down

0 comments on commit 03fc3ab

Please sign in to comment.