From feef08fc8ccfa4196a71e58c0cdc64e27f7e0fca Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Fri, 22 May 2020 12:07:07 +0200 Subject: [PATCH 1/6] Snocket.Accept Rescue Alex Vieth's 'Accept' modification. I couldn't cherry-pick the commit since it was burried inside a merge commit. There's no proper way to fix `Ouroboros.Network.Soocket.fromSnocket`, but this is ok, as it will be removed in a later commit. --- .../src/Ouroboros/Network/Snocket.hs | 103 +++++++++++++----- .../src/Ouroboros/Network/Socket.hs | 15 ++- .../test/Test/Ouroboros/Network/Socket.hs | 56 +++++----- 3 files changed, 117 insertions(+), 57 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index 34631f48b8a..c1ebaefbe29 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -8,6 +8,7 @@ module Ouroboros.Network.Snocket ( -- * Snocket Interface Accept (..) + , Accepted (..) , AddressFamily (..) , Snocket (..) -- ** Socket based Snocktes @@ -32,6 +33,7 @@ import Control.Monad (when) import Control.Monad.Class.MonadTime (DiffTime) import Control.Tracer (Tracer) import Data.Bifunctor (Bifunctor (..)) +import Data.Bifoldable (Bifoldable (..)) import Data.Hashable import GHC.Generics (Generic) import Quiet (Quiet (..)) @@ -96,13 +98,26 @@ import Ouroboros.Network.IOManager -- descriptor by `createNamedPipe`, see 'namedPipeSnocket'. -- newtype Accept m fd addr = Accept - { runAccept :: m (fd, addr, Accept m fd addr) + { runAccept :: m (Accepted fd addr, Accept m fd addr) } instance Functor m => Bifunctor (Accept m) where - bimap f g ac = Accept $ h <$> runAccept ac + bimap f g (Accept ac) = Accept (h <$> ac) where - h (fd, addr, next) = (f fd, g addr, bimap f g next) + h (accepted, next) = (bimap f g accepted, bimap f g next) + + +data Accepted fd addr where + AcceptFailure :: !SomeException -> Accepted fd addr + Accepted :: !fd -> !addr -> Accepted fd addr + +instance Bifunctor Accepted where + bimap f g (Accepted fd addr) = Accepted (f fd) (g addr) + bimap _ _ (AcceptFailure err) = AcceptFailure err + +instance Bifoldable Accepted where + bifoldMap f g (Accepted fd addr) = f fd <> g addr + bifoldMap _ _ (AcceptFailure _) = mempty -- | BSD accept loop. @@ -112,21 +127,35 @@ berkeleyAccept :: IOManager -> Accept IO Socket SockAddr berkeleyAccept ioManager sock = go where - go = Accept $ do - (sock', addr') <- + go = Accept (acceptOne `catch` handleException) + + acceptOne + :: IO ( Accepted Socket SockAddr + , Accept IO Socket SockAddr + ) + acceptOne = + bracketOnError #if !defined(mingw32_HOST_OS) - Socket.accept sock + (Socket.accept sock) #else - Win32.Async.accept sock + (Win32.Async.accept sock) #endif - associateWithIOManager ioManager (Right sock') - `catch` \(e :: IOException) -> do - Socket.close sock' - throwIO e - `catch` \(SomeAsyncException _) -> do - Socket.close sock' - throwIO e - return (sock', addr', go) + (Socket.close . fst) + $ \(sock', addr') -> do + associateWithIOManager ioManager (Right sock') + return (Accepted sock' addr', go) + + -- Only non-async exceptions will be caught and put into the + -- AcceptFailure variant. + handleException + :: SomeException + -> IO ( Accepted Socket SockAddr + , Accept IO Socket SockAddr + ) + handleException err = + case fromException err of + Just (SomeAsyncException _) -> throwIO err + Nothing -> pure (AcceptFailure err, go) -- | Local address, on Unix is associated with `Socket.AF_UNIX` family, on -- @@ -186,6 +215,9 @@ data Snocket m fd addr = Snocket { , bind :: fd -> addr -> m () , listen :: fd -> m () + -- SomeException is chosen here to avoid having to include it in the Snocket + -- type, and therefore refactoring a bunch of stuff. + -- FIXME probably a good idea to abstract it. , accept :: fd -> Accept m fd addr , close :: fd -> m () @@ -346,7 +378,7 @@ localSnocket ioManager path = Snocket { , accept = \sock@(LocalSocket hpipe) -> Accept $ do Win32.Async.connectNamedPipe hpipe - return (sock, localAddress, acceptNext) + return (Accepted sock localAddress, acceptNext) -- Win32.closeHandle is not interrupible , close = Win32.closeHandle . getLocalHandle @@ -358,19 +390,40 @@ localSnocket ioManager path = Snocket { localAddress = LocalAddress path acceptNext :: Accept IO LocalSocket LocalAddress - acceptNext = Accept $ do - hpipe <- Win32.createNamedPipe + acceptNext = go + where + go = Accept (acceptOne `catch` handleIOException) + + handleIOException + :: IOException + -> IO ( Accepted LocalSocket LocalAddress + , Accept IO LocalSocket LocalAddress + ) + handleIOException err = + pure ( AcceptFailure (toException err) + , go + ) + + acceptOne + :: IO ( Accepted LocalSocket LocalAddress + , Accept IO LocalSocket LocalAddress + ) + acceptOne = + bracketOnError + (Win32.createNamedPipe path (Win32.pIPE_ACCESS_DUPLEX .|. Win32.fILE_FLAG_OVERLAPPED) (Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE) Win32.pIPE_UNLIMITED_INSTANCES - 65536 -- outbound pipe size - 16384 -- inbound pipe size - 0 -- default timeout - Nothing -- default security - associateWithIOManager ioManager (Left hpipe) - Win32.Async.connectNamedPipe hpipe - return (LocalSocket hpipe, localAddress, acceptNext) + 65536 -- outbound pipe size + 16384 -- inbound pipe size + 0 -- default timeout + Nothing) -- default security + Win32.closeHandle + $ \hpipe -> do + associateWithIOManager ioManager (Left hpipe) + Win32.Async.connectNamedPipe hpipe + return (Accepted (LocalSocket hpipe) localAddress, go) -- local snocket on unix #else diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index bcc73f0b6b2..677019cbf30 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -433,11 +433,16 @@ fromSnocket tblVar sn sd = go (Snocket.accept sn sd) where go :: Snocket.Accept IO fd addr -> Server.Socket addr fd go (Snocket.Accept accept) = Server.Socket $ do - (sd', remoteAddr, next) <- accept - -- TOOD: we don't need to that on each accept - localAddr <- Snocket.getLocalAddr sn sd' - atomically $ addConnection tblVar remoteAddr localAddr Nothing - pure (remoteAddr, sd', close remoteAddr localAddr sd', go next) + (result, next) <- accept + case result of + Snocket.Accepted sd' remoteAddr -> do + -- TOOD: we don't need to that on each accept + localAddr <- Snocket.getLocalAddr sn sd' + atomically $ addConnection tblVar remoteAddr localAddr Nothing + pure (remoteAddr, sd', close remoteAddr localAddr sd', go next) + Snocket.AcceptFailure err -> + -- the is no way to construct 'Server.Socket'; This will be removed in a later commit! + throwIO err close remoteAddr localAddr sd' = do removeConnection tblVar remoteAddr localAddr diff --git a/ouroboros-network-framework/test/Test/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/test/Test/Ouroboros/Network/Socket.hs index 17c6af852e9..24854ef62c0 100644 --- a/ouroboros-network-framework/test/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/test/Test/Ouroboros/Network/Socket.hs @@ -13,6 +13,7 @@ module Test.Ouroboros.Network.Socket (tests) where import Data.Void (Void) import Data.List (mapAccumL) +import Data.Bifoldable (bitraverse_) import qualified Data.ByteString.Lazy as BL import Data.Proxy (Proxy (..)) import Data.Time.Clock (UTCTime, getCurrentTime) @@ -335,20 +336,22 @@ prop_socket_recv_error f rerr = -- accept a connection and start mux on it bracket (runAccept $ accept snocket sd) - (\(sd', _, _) -> Socket.close sd') - $ \(sd', _, _) -> do - remoteAddress <- Socket.getPeerName sd' - let timeout = if rerr == RecvSDUTimeout then 0.10 - else (-1) -- No timeout - bearer = Mx.socketAsMuxBearer timeout nullTracer sd' - connectionId = ConnectionId { - localAddress = Socket.addrAddress muxAddress, - remoteAddress - } - _ <- async $ do - threadDelay 0.1 - atomically $ putTMVar lock () - Mx.muxStart nullTracer (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) bearer + (bitraverse_ Socket.close pure . fst) + $ \(accepted, _acceptNext) -> case accepted of + AcceptFailure err -> throwIO err + Accepted sd' _ -> do + remoteAddress <- Socket.getPeerName sd' + let timeout = if rerr == RecvSDUTimeout then 0.10 + else (-1) -- No timeout + bearer = Mx.socketAsMuxBearer timeout nullTracer sd' + connectionId = ConnectionId { + localAddress = Socket.addrAddress muxAddress, + remoteAddress + } + _ <- async $ do + threadDelay 0.1 + atomically $ putTMVar lock () + Mx.muxStart nullTracer (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) bearer ) $ \muxAsync -> do @@ -414,22 +417,21 @@ prop_socket_send_error rerr = -- accept a connection and start mux on it bracket (runAccept $ accept snocket sd) - (\(sd', _, _) -> Socket.close sd') - (\(sd', _, _) -> - let sduTimeout = if rerr == SendSDUTimeout then 0.10 - else (-1) -- No timeout - bearer = Mx.socketAsMuxBearer sduTimeout nullTracer sd' - blob = BL.pack $ replicate 0xffff 0xa5 in - withTimeoutSerial $ \timeout -> - -- send maximum mux sdus until we've filled the window. - replicateM 100 $ do - ((), Nothing) <$ write bearer timeout (wrap blob ResponderDir (MiniProtocolNum 0)) - ) - + (bitraverse_ Socket.close pure . fst) + $ \(accepted, _acceptNext) -> case accepted of + AcceptFailure err -> throwIO err + Accepted sd' _ -> do + let sduTimeout = if rerr == SendSDUTimeout then 0.10 + else (-1) -- No timeout + bearer = Mx.socketAsMuxBearer sduTimeout nullTracer sd' + blob = BL.pack $ replicate 0xffff 0xa5 + withTimeoutSerial $ \timeout -> + -- send maximum mux sdus until we've filled the window. + replicateM 100 $ do + ((), Nothing) <$ write bearer timeout (wrap blob ResponderDir (MiniProtocolNum 0)) ) $ \muxAsync -> do - sd' <- openToConnect snocket addr -- connect to muxAddress _ <- connect snocket sd' addr From cbda9be3f4b870ecdf79c2b9d95f30d74920a666 Mon Sep 17 00:00:00 2001 From: Karl Knutsson Date: Fri, 11 Jun 2021 10:38:29 +0200 Subject: [PATCH 2/6] Use a counter to provide a unique remote addr The inbound governor requires all clients to have a unique address. UNIX sockets and windows named pipes lack this so we use a counter to generate remote addresses for local clients. --- .../src/Ouroboros/Network/Snocket.hs | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index c1ebaefbe29..4f5b662467e 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingVia #-} @@ -35,6 +36,7 @@ import Control.Tracer (Tracer) import Data.Bifunctor (Bifunctor (..)) import Data.Bifoldable (Bifoldable (..)) import Data.Hashable +import Data.Word import GHC.Generics (Generic) import Quiet (Quiet (..)) #if !defined(mingw32_HOST_OS) @@ -125,15 +127,16 @@ instance Bifoldable Accepted where berkeleyAccept :: IOManager -> Socket -> Accept IO Socket SockAddr -berkeleyAccept ioManager sock = go +berkeleyAccept ioManager sock = go 0 where - go = Accept (acceptOne `catch` handleException) + go cnt = Accept (acceptOne cnt `catch` handleException cnt) acceptOne - :: IO ( Accepted Socket SockAddr + :: Word64 + -> IO ( Accepted Socket SockAddr , Accept IO Socket SockAddr ) - acceptOne = + acceptOne !cnt = bracketOnError #if !defined(mingw32_HOST_OS) (Socket.accept sock) @@ -143,19 +146,33 @@ berkeleyAccept ioManager sock = go (Socket.close . fst) $ \(sock', addr') -> do associateWithIOManager ioManager (Right sock') - return (Accepted sock' addr', go) + + -- UNIX sockets don't provide a unique endpoint for the remote + -- side, but the InboundGovernor/Server requires one in order to + -- track connections. + -- So to differentiate clients we use a simple counter as the + -- remote end's address. + -- + addr'' <- case addr' of + Socket.SockAddrUnix _ -> + return $ Socket.SockAddrUnix $ + "temp-" ++ show cnt + _ -> return addr' + + return (Accepted sock' addr'', go $ succ cnt) -- Only non-async exceptions will be caught and put into the -- AcceptFailure variant. handleException - :: SomeException + :: Word64 + -> SomeException -> IO ( Accepted Socket SockAddr , Accept IO Socket SockAddr ) - handleException err = + handleException !cnt err = case fromException err of Just (SomeAsyncException _) -> throwIO err - Nothing -> pure (AcceptFailure err, go) + Nothing -> pure (AcceptFailure err, go cnt) -- | Local address, on Unix is associated with `Socket.AF_UNIX` family, on -- @@ -390,25 +407,27 @@ localSnocket ioManager path = Snocket { localAddress = LocalAddress path acceptNext :: Accept IO LocalSocket LocalAddress - acceptNext = go + acceptNext = go 0 where - go = Accept (acceptOne `catch` handleIOException) + go cnt = Accept (acceptOne cnt `catch` handleIOException cnt) handleIOException - :: IOException + :: Word64 + -> IOException -> IO ( Accepted LocalSocket LocalAddress , Accept IO LocalSocket LocalAddress ) - handleIOException err = + handleIOException !cnt err = pure ( AcceptFailure (toException err) - , go + , go cnt ) acceptOne - :: IO ( Accepted LocalSocket LocalAddress + :: Word64 + -> IO ( Accepted LocalSocket LocalAddress , Accept IO LocalSocket LocalAddress ) - acceptOne = + acceptOne !cnt = bracketOnError (Win32.createNamedPipe path @@ -423,7 +442,13 @@ localSnocket ioManager path = Snocket { $ \hpipe -> do associateWithIOManager ioManager (Left hpipe) Win32.Async.connectNamedPipe hpipe - return (Accepted (LocalSocket hpipe) localAddress, go) + -- InboundGovernor/Server requires a unique address for the + -- remote end one in order to track connections. + -- So to differentiate clients we use a simple counter as the + -- remote end's address. + -- + let addr = localAddressFromPath $ "temp-" ++ show cnt + return (Accepted (LocalSocket hpipe) addr, go $ succ cnt ) -- local snocket on unix #else From ab942f91df4150fa5209ab1371891914a4f9333b Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Thu, 16 Sep 2021 11:55:11 +0200 Subject: [PATCH 3/6] snockets - monadic toBearer Running monadic actions inside toBearer is useful when using file descriptors of a simulation environment. --- .../src/Ouroboros/Network/Snocket.hs | 13 +++++++++---- .../src/Ouroboros/Network/Socket.hs | 14 ++++++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index 4f5b662467e..798708b2918 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -239,10 +239,15 @@ data Snocket m fd addr = Snocket { , close :: fd -> m () - , toBearer :: DiffTime -> Tracer m MuxTrace -> fd -> MuxBearer m + , toBearer :: DiffTime -> Tracer m MuxTrace -> fd -> m (MuxBearer m) } +pureBearer :: Monad m + => (DiffTime -> Tracer m MuxTrace -> fd -> MuxBearer m) + -> DiffTime -> Tracer m MuxTrace -> fd -> m (MuxBearer m) +pureBearer f = \timeout tr fd -> return (f timeout tr fd) + -- -- Socket based Snockets -- @@ -306,7 +311,7 @@ socketSnocket ioManager = Snocket { -- should be fixed upstream, once that's done we can remove -- `unitnerruptibleMask_' , close = uninterruptibleMask_ . Socket.close - , toBearer = Mx.socketAsMuxBearer + , toBearer = pureBearer Mx.socketAsMuxBearer } where openSocket :: AddressFamily SockAddr -> IO Socket @@ -400,7 +405,7 @@ localSnocket ioManager path = Snocket { -- Win32.closeHandle is not interrupible , close = Win32.closeHandle . getLocalHandle - , toBearer = \_sduTimeout tr -> namedPipeAsBearer tr . getLocalHandle + , toBearer = \_sduTimeout tr -> pure . namedPipeAsBearer tr . getLocalHandle } where localAddress :: LocalAddress @@ -468,7 +473,7 @@ localSnocket ioManager _ = , open = openSocket , openToConnect = \_addr -> openSocket LocalFamily , close = uninterruptibleMask_ . Socket.close . getLocalHandle - , toBearer = \df tr (LocalSocket sd) -> Mx.socketAsMuxBearer df tr sd + , toBearer = \df tr (LocalSocket sd) -> pure (Mx.socketAsMuxBearer df tr sd) } where toLocalAddress :: SockAddr -> LocalAddress diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index 677019cbf30..cf57b027ee1 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -250,9 +250,10 @@ connectToNode' sn handshakeCodec handshakeTimeLimits versionDataCodec NetworkCon muxTracer <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` nctMuxTracer ts_start <- getMonotonicTime + handshakeBearer <- Snocket.toBearer sn sduHandshakeTimeout muxTracer sd app_e <- runHandshakeClient - (Snocket.toBearer sn sduHandshakeTimeout muxTracer sd) + handshakeBearer connectionId -- TODO: push 'HandshakeArguments' up the call stack. HandshakeArguments { @@ -275,10 +276,11 @@ connectToNode' sn handshakeCodec handshakeTimeLimits versionDataCodec NetworkCon Right (app, _versionNumber, _agreedOptions) -> do traceWith muxTracer $ Mx.MuxTraceHandshakeClientEnd (diffTime ts_end ts_start) + bearer <- Snocket.toBearer sn sduTimeout muxTracer sd Mx.muxStart muxTracer (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) - (Snocket.toBearer sn sduTimeout muxTracer sd) + bearer -- Wraps a Socket inside a Snocket and calls connectToNode' @@ -374,9 +376,12 @@ beginConnection sn muxTracer handshakeTracer handshakeCodec handshakeTimeLimits traceWith muxTracer' $ Mx.MuxTraceHandshakeStart + handshakeBearer <- Snocket.toBearer sn + sduHandshakeTimeout + muxTracer' sd app_e <- runHandshakeServer - (Snocket.toBearer sn sduHandshakeTimeout muxTracer' sd) + handshakeBearer connectionId HandshakeArguments { haHandshakeTracer = handshakeTracer, @@ -398,10 +403,11 @@ beginConnection sn muxTracer handshakeTracer handshakeCodec handshakeTimeLimits Right (SomeResponderApplication app, _versionNumber, _agreedOptions) -> do traceWith muxTracer' $ Mx.MuxTraceHandshakeServerEnd + bearer <- Snocket.toBearer sn sduTimeout muxTracer' sd Mx.muxStart muxTracer' (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) - (Snocket.toBearer sn sduTimeout muxTracer' sd) + bearer RejectConnection st' _peerid -> pure $ Server.Reject st' From dc7b4d4c11873d2c0caf32f5f51ee897bc77d3af Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Thu, 16 Sep 2021 12:01:13 +0200 Subject: [PATCH 4/6] snocket: support getLocalName for named pipes (Windows) On Windows there's no way of getting path of the named pipe. Unfortunatelly, `GetFinalNameByHandle` does not support named pipes. This patch alters Snocket interface which allows to store the path within 'LocalSnocket'. For Window's named pipes `getLocalAddr` and `getRemoteAddr` return the same path. --- ouroboros-network-framework/demo/ping-pong.hs | 11 ++- .../src/Ouroboros/Network/Snocket.hs | 97 ++++++++++--------- ouroboros-network/demo/chain-sync.hs | 8 +- .../src/Ouroboros/Network/Diffusion.hs | 4 +- 4 files changed, 65 insertions(+), 55 deletions(-) diff --git a/ouroboros-network-framework/demo/ping-pong.hs b/ouroboros-network-framework/demo/ping-pong.hs index 7a246290f62..0fb9f549cfe 100644 --- a/ouroboros-network-framework/demo/ping-pong.hs +++ b/ouroboros-network-framework/demo/ping-pong.hs @@ -26,6 +26,7 @@ import System.Exit import Ouroboros.Network.Socket import Ouroboros.Network.Snocket +import qualified Ouroboros.Network.Snocket as Snocket import Ouroboros.Network.Mux import Ouroboros.Network.ErrorPolicy import Ouroboros.Network.IOManager @@ -107,7 +108,7 @@ clientPingPong :: Bool -> IO () clientPingPong pipelined = withIOManager $ \iomgr -> connectToNode - (localSnocket iomgr defaultLocalSocketAddrPath) + (Snocket.localSnocket iomgr) unversionedHandshakeCodec noTimeLimitsHandshake (cborTermVersionDataCodec unversionedProtocolDataCodec) @@ -145,7 +146,7 @@ serverPingPong = networkState <- newNetworkMutableState _ <- async $ cleanNetworkMutableState networkState withServerNode - (localSnocket iomgr defaultLocalSocketAddrPath) + (Snocket.localSnocket iomgr) nullNetworkServerTracers networkState (AcceptedConnectionsLimit maxBound maxBound 0) @@ -203,9 +204,9 @@ demoProtocol1 pingPong pingPong' = clientPingPong2 :: IO () clientPingPong2 = - withIOManager $ \iomgr -> + withIOManager $ \iomgr -> do connectToNode - (localSnocket iomgr defaultLocalSocketAddrPath) + (Snocket.localSnocket iomgr) unversionedHandshakeCodec noTimeLimitsHandshake (cborTermVersionDataCodec unversionedProtocolDataCodec) @@ -256,7 +257,7 @@ serverPingPong2 = networkState <- newNetworkMutableState _ <- async $ cleanNetworkMutableState networkState withServerNode - (localSnocket iomgr defaultLocalSocketAddrPath) + (Snocket.localSnocket iomgr) nullNetworkServerTracers networkState (AcceptedConnectionsLimit maxBound maxBound 0) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index 798708b2918..e802f37f00e 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -5,6 +5,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} module Ouroboros.Network.Snocket ( -- * Snocket Interface @@ -187,20 +188,19 @@ instance Hashable LocalAddress where -- | We support either sockets or named pipes. -- +-- 'LocalFamily' requires 'LocalAddress', this is needed to provide path of the +-- openned Win32 'HANDLE'. +-- data AddressFamily addr where SocketFamily :: !Socket.Family -> AddressFamily Socket.SockAddr - LocalFamily :: AddressFamily LocalAddress + LocalFamily :: !LocalAddress -> AddressFamily LocalAddress -instance Eq (AddressFamily addr) where - SocketFamily fam0 == SocketFamily fam1 = fam0 == fam1 - LocalFamily == LocalFamily = True +deriving instance Eq addr => Eq (AddressFamily addr) +deriving instance Show addr => Show (AddressFamily addr) -instance Show (AddressFamily addr) where - show (SocketFamily fam) = show fam - show LocalFamily = "LocalFamily" -- | Abstract communication interface that can be used by more than -- 'Socket'. Snockets are polymorphic over monad which is used, this feature @@ -213,7 +213,7 @@ data Snocket m fd addr = Snocket { , addrFamily :: addr -> AddressFamily addr -- | Open a file descriptor (socket / namedPipe). For named pipes this is - -- using 'CreateNamedPipe' syscall, for Berkeley sockets 'socket' is used.. + -- using 'CreateNamedPipe' syscall, for Berkeley sockets 'socket' is used. -- , open :: AddressFamily addr -> m fd @@ -223,7 +223,7 @@ data Snocket m fd addr = Snocket { -- For named pipes we need full 'addr' rather than just address family as -- it is for sockets. -- - , openToConnect :: addr -> m fd + , openToConnect :: addr -> m fd -- | `connect` is only needed for Berkeley sockets, for named pipes this is -- no-op. @@ -342,25 +342,41 @@ type LocalHandle = Socket #endif -- | System dependent LocalSnocket type +-- +#if defined(mingw32_HOST_OS) +data LocalSocket = LocalSocket { getLocalHandle :: LocalHandle + , getLocalPath :: LocalAddress + } + deriving (Eq, Generic) + deriving Show via Quiet LocalSocket +#else newtype LocalSocket = LocalSocket { getLocalHandle :: LocalHandle } deriving (Eq, Generic) deriving Show via Quiet LocalSocket +#endif -- | System dependent LocalSnocket type LocalSnocket = Snocket IO LocalSocket LocalAddress -localSnocket :: IOManager -> FilePath -> LocalSnocket + +-- | Create a 'LocalSnocket'. +-- +-- On /Windows/, there is no way to get path associated to a named pipe. To go +-- around this, the address passed to 'open' via 'LocalFamily' will be +-- referenced by 'LocalSocket'. +-- +localSnocket :: IOManager -> LocalSnocket #if defined(mingw32_HOST_OS) -localSnocket ioManager path = Snocket { - getLocalAddr = \_ -> return localAddress - , getRemoteAddr = \_ -> return localAddress - , addrFamily = \_ -> LocalFamily +localSnocket ioManager = Snocket { + getLocalAddr = return . getLocalPath + , getRemoteAddr = return . getLocalPath + , addrFamily = LocalFamily - , open = \_addrFamily -> do + , open = \(LocalFamily addr) -> do hpipe <- Win32.createNamedPipe - path + (getFilePath addr) (Win32.pIPE_ACCESS_DUPLEX .|. Win32.fILE_FLAG_OVERLAPPED) - (Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE) + (Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE) Win32.pIPE_UNLIMITED_INSTANCES 65536 -- outbound pipe size 16384 -- inbound pipe size @@ -373,7 +389,7 @@ localSnocket ioManager path = Snocket { `catch` \(SomeAsyncException _) -> do Win32.closeHandle hpipe throwIO e - pure (LocalSocket hpipe) + pure (LocalSocket hpipe addr) -- To connect, simply create a file whose name is the named pipe name. , openToConnect = \(LocalAddress pipeName) -> do @@ -391,16 +407,16 @@ localSnocket ioManager path = Snocket { `catch` \(SomeAsyncException _) -> do Win32.closeHandle hpipe throwIO e - return (LocalSocket hpipe) + return (LocalSocket hpipe (LocalAddress pipeName)) , connect = \_ _ -> pure () -- Bind and listen are no-op. , bind = \_ _ -> pure () , listen = \_ -> pure () - , accept = \sock@(LocalSocket hpipe) -> Accept $ do + , accept = \sock@(LocalSocket hpipe addr) -> Accept $ do Win32.Async.connectNamedPipe hpipe - return (Accepted sock localAddress, acceptNext) + return (Accepted sock addr, acceptNext 0 addr) -- Win32.closeHandle is not interrupible , close = Win32.closeHandle . getLocalHandle @@ -408,36 +424,29 @@ localSnocket ioManager path = Snocket { , toBearer = \_sduTimeout tr -> pure . namedPipeAsBearer tr . getLocalHandle } where - localAddress :: LocalAddress - localAddress = LocalAddress path - - acceptNext :: Accept IO LocalSocket LocalAddress - acceptNext = go 0 + acceptNext :: Word64 -> LocalAddress -> Accept IO LocalSocket LocalAddress + acceptNext !cnt addr = Accept (acceptOne `catch` handleIOException) where - go cnt = Accept (acceptOne cnt `catch` handleIOException cnt) - handleIOException - :: Word64 - -> IOException + :: IOException -> IO ( Accepted LocalSocket LocalAddress , Accept IO LocalSocket LocalAddress ) - handleIOException !cnt err = + handleIOException err = pure ( AcceptFailure (toException err) - , go cnt + , acceptNext (succ cnt) addr ) acceptOne - :: Word64 - -> IO ( Accepted LocalSocket LocalAddress + :: IO ( Accepted LocalSocket LocalAddress , Accept IO LocalSocket LocalAddress ) - acceptOne !cnt = + acceptOne = bracketOnError (Win32.createNamedPipe - path + (getFilePath addr) (Win32.pIPE_ACCESS_DUPLEX .|. Win32.fILE_FLAG_OVERLAPPED) - (Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE) + (Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE) Win32.pIPE_UNLIMITED_INSTANCES 65536 -- outbound pipe size 16384 -- inbound pipe size @@ -452,17 +461,17 @@ localSnocket ioManager path = Snocket { -- So to differentiate clients we use a simple counter as the -- remote end's address. -- - let addr = localAddressFromPath $ "temp-" ++ show cnt - return (Accepted (LocalSocket hpipe) addr, go $ succ cnt ) + let addr' = LocalAddress $ "\\\\.\\pipe\\ouroboros-network-temp-" ++ show cnt + return (Accepted (LocalSocket hpipe addr') addr', acceptNext (succ cnt) addr) -- local snocket on unix #else -localSnocket ioManager _ = +localSnocket ioManager = Snocket { getLocalAddr = fmap toLocalAddress . Socket.getSocketName . getLocalHandle , getRemoteAddr = fmap toLocalAddress . Socket.getPeerName . getLocalHandle - , addrFamily = const LocalFamily + , addrFamily = LocalFamily , connect = \(LocalSocket s) addr -> Socket.connect s (fromLocalAddress addr) , bind = \(LocalSocket fd) addr -> Socket.bind fd (fromLocalAddress addr) @@ -471,7 +480,7 @@ localSnocket ioManager _ = . berkeleyAccept ioManager . getLocalHandle , open = openSocket - , openToConnect = \_addr -> openSocket LocalFamily + , openToConnect = \addr -> openSocket (LocalFamily addr) , close = uninterruptibleMask_ . Socket.close . getLocalHandle , toBearer = \df tr (LocalSocket sd) -> pure (Mx.socketAsMuxBearer df tr sd) } @@ -484,7 +493,7 @@ localSnocket ioManager _ = fromLocalAddress = SockAddrUnix . getFilePath openSocket :: AddressFamily LocalAddress -> IO LocalSocket - openSocket LocalFamily = do + openSocket (LocalFamily _addr) = do sd <- Socket.socket AF_UNIX Socket.Stream Socket.defaultProtocol associateWithIOManager ioManager (Right sd) -- open is designed to be used in `bracket`, and thus it's called with @@ -518,7 +527,7 @@ socketFileDescriptor = fmap (FileDescriptor . fromIntegral) . Socket.unsafeFdSoc localSocketFileDescriptor :: LocalSocket -> IO FileDescriptor #if defined(mingw32_HOST_OS) localSocketFileDescriptor = - \(LocalSocket fd) -> case ptrToIntPtr fd of + \(LocalSocket fd _) -> case ptrToIntPtr fd of IntPtr i -> return (FileDescriptor i) #else localSocketFileDescriptor = socketFileDescriptor . getLocalHandle diff --git a/ouroboros-network/demo/chain-sync.hs b/ouroboros-network/demo/chain-sync.hs index 07344bd3a7b..3646fecfad1 100644 --- a/ouroboros-network/demo/chain-sync.hs +++ b/ouroboros-network/demo/chain-sync.hs @@ -152,7 +152,7 @@ clientChainSync sockPaths = withIOManager $ \iocp -> forConcurrently_ (zip [0..] sockPaths) $ \(index, sockPath) -> do threadDelay (50000 * index) connectToNode - (localSnocket iocp sockPath) + (localSnocket iocp) unversionedHandshakeCodec noTimeLimitsHandshake (cborTermVersionDataCodec unversionedProtocolDataCodec) @@ -182,7 +182,7 @@ serverChainSync sockAddr = withIOManager $ \iocp -> do networkState <- newNetworkMutableState _ <- async $ cleanNetworkMutableState networkState withServerNode - (localSnocket iocp defaultLocalSocketAddrPath) + (localSnocket iocp) nullNetworkServerTracers networkState (AcceptedConnectionsLimit maxBound maxBound 0) @@ -365,7 +365,7 @@ clientBlockFetch sockAddrs = withIOManager $ \iocp -> do peerAsyncs <- sequence [ async $ connectToNode - (localSnocket iocp defaultLocalSocketAddrPath) + (localSnocket iocp) unversionedHandshakeCodec noTimeLimitsHandshake (cborTermVersionDataCodec unversionedProtocolDataCodec) @@ -417,7 +417,7 @@ serverBlockFetch sockAddr = withIOManager $ \iocp -> do networkState <- newNetworkMutableState _ <- async $ cleanNetworkMutableState networkState withServerNode - (localSnocket iocp defaultLocalSocketAddrPath) + (localSnocket iocp) nullNetworkServerTracers networkState (AcceptedConnectionsLimit maxBound maxBound 0) diff --git a/ouroboros-network/src/Ouroboros/Network/Diffusion.hs b/ouroboros-network/src/Ouroboros/Network/Diffusion.hs index 64e6c68febf..81954418f1c 100644 --- a/ouroboros-network/src/Ouroboros/Network/Diffusion.hs +++ b/ouroboros-network/src/Ouroboros/Network/Diffusion.hs @@ -349,13 +349,13 @@ runDataDiffusion tracers case a of (Socket.SockAddrUnix path) -> do traceWith dtDiffusionInitializationTracer $ UsingSystemdSocket path - return (LocalSocket sd, Snocket.localSnocket iocp path) + return (LocalSocket sd, Snocket.localSnocket iocp) unsupportedAddr -> do traceWith dtDiffusionInitializationTracer $ UnsupportedLocalSystemdSocket unsupportedAddr throwIO UnsupportedLocalSocketType #endif Right addr -> do - let sn = Snocket.localSnocket iocp addr + let sn = Snocket.localSnocket iocp traceWith dtDiffusionInitializationTracer $ CreateSystemdSocketForSnocketPath addr sd <- Snocket.open sn (Snocket.addrFamily sn $ Snocket.localAddressFromPath addr) traceWith dtDiffusionInitializationTracer $ CreatedLocalSocket addr From f1d68a306ff082c7b8b5123f31f1fab640a9e2d0 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Thu, 16 Sep 2021 16:02:53 +0200 Subject: [PATCH 5/6] snocket: store local and remote path --- .../src/Ouroboros/Network/Snocket.hs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index e802f37f00e..a32c1ff2a26 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -345,7 +345,13 @@ type LocalHandle = Socket -- #if defined(mingw32_HOST_OS) data LocalSocket = LocalSocket { getLocalHandle :: LocalHandle + -- ^ underlying windows 'HANDLE' , getLocalPath :: LocalAddress + -- ^ original path, used when creating the handle + , getRemotePath :: LocalAddress + -- ^ unique identifier (not a real path). It + -- makes the pair of local and remote + -- addresses unique. } deriving (Eq, Generic) deriving Show via Quiet LocalSocket @@ -369,7 +375,7 @@ localSnocket :: IOManager -> LocalSnocket #if defined(mingw32_HOST_OS) localSnocket ioManager = Snocket { getLocalAddr = return . getLocalPath - , getRemoteAddr = return . getLocalPath + , getRemoteAddr = return . getRemotePath , addrFamily = LocalFamily , open = \(LocalFamily addr) -> do @@ -389,7 +395,7 @@ localSnocket ioManager = Snocket { `catch` \(SomeAsyncException _) -> do Win32.closeHandle hpipe throwIO e - pure (LocalSocket hpipe addr) + pure (LocalSocket hpipe addr (LocalAddress "")) -- To connect, simply create a file whose name is the named pipe name. , openToConnect = \(LocalAddress pipeName) -> do @@ -407,14 +413,14 @@ localSnocket ioManager = Snocket { `catch` \(SomeAsyncException _) -> do Win32.closeHandle hpipe throwIO e - return (LocalSocket hpipe (LocalAddress pipeName)) + return (LocalSocket hpipe (LocalAddress pipeName) (LocalAddress pipeName)) , connect = \_ _ -> pure () -- Bind and listen are no-op. , bind = \_ _ -> pure () , listen = \_ -> pure () - , accept = \sock@(LocalSocket hpipe addr) -> Accept $ do + , accept = \sock@(LocalSocket hpipe addr _) -> Accept $ do Win32.Async.connectNamedPipe hpipe return (Accepted sock addr, acceptNext 0 addr) @@ -462,7 +468,7 @@ localSnocket ioManager = Snocket { -- remote end's address. -- let addr' = LocalAddress $ "\\\\.\\pipe\\ouroboros-network-temp-" ++ show cnt - return (Accepted (LocalSocket hpipe addr') addr', acceptNext (succ cnt) addr) + return (Accepted (LocalSocket hpipe addr addr') addr', acceptNext (succ cnt) addr) -- local snocket on unix #else @@ -527,7 +533,7 @@ socketFileDescriptor = fmap (FileDescriptor . fromIntegral) . Socket.unsafeFdSoc localSocketFileDescriptor :: LocalSocket -> IO FileDescriptor #if defined(mingw32_HOST_OS) localSocketFileDescriptor = - \(LocalSocket fd _) -> case ptrToIntPtr fd of + \(LocalSocket fd _ _) -> case ptrToIntPtr fd of IntPtr i -> return (FileDescriptor i) #else localSocketFileDescriptor = socketFileDescriptor . getLocalHandle From e7efa148f2c5428c5b7ead763fdc10434a035307 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Thu, 16 Sep 2021 16:03:30 +0200 Subject: [PATCH 6/6] snocket: make Window's LocalSocket strict This will match the semantics of LocalSocket on other platforms. --- .../src/Ouroboros/Network/Snocket.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index a32c1ff2a26..450030d3e39 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -344,11 +344,11 @@ type LocalHandle = Socket -- | System dependent LocalSnocket type -- #if defined(mingw32_HOST_OS) -data LocalSocket = LocalSocket { getLocalHandle :: LocalHandle +data LocalSocket = LocalSocket { getLocalHandle :: !LocalHandle -- ^ underlying windows 'HANDLE' - , getLocalPath :: LocalAddress + , getLocalPath :: !LocalAddress -- ^ original path, used when creating the handle - , getRemotePath :: LocalAddress + , getRemotePath :: !LocalAddress -- ^ unique identifier (not a real path). It -- makes the pair of local and remote -- addresses unique.