Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review of pruning #3495

Merged
merged 9 commits into from
Nov 22, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
-- Undecidable instances are need for 'Show' instance of 'ConnectionState'.
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE QuantifiedConstraints #-}
Expand All @@ -25,7 +26,7 @@ module Ouroboros.Network.ConnectionManager.Core
) where

import Control.Exception (assert)
import Control.Monad (when)
import Control.Monad (forM_, guard, when)
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow hiding (handle)
Expand Down Expand Up @@ -395,6 +396,22 @@ getConnType (TerminatingState _connId _connThread _handleError) = Nothing
getConnType TerminatedState {} = Nothing


-- | Return 'True' if a connection is inbound. This must agree with
-- 'connectionStateToCounters'. Both are used for prunning.
--
isInboundConn :: ConnectionState peerAddr handle handleError version m -> Bool
isInboundConn ReservedOutboundState = False
isInboundConn (UnnegotiatedState pr _connId _connThread) = pr == Inbound
isInboundConn OutboundUniState {} = False
isInboundConn OutboundDupState {} = False
isInboundConn OutboundIdleState {} = False
isInboundConn InboundIdleState {} = True
isInboundConn InboundState {} = True
isInboundConn DuplexState {} = True
isInboundConn TerminatingState {} = False
isInboundConn TerminatedState {} = False


abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState
abstractState = \s -> case s of
Unknown -> UnknownConnectionSt
Expand Down Expand Up @@ -429,6 +446,8 @@ defaultResetTimeout :: DiffTime
defaultResetTimeout = 5


newtype PruneAction m = PruneAction { runPruneAction :: m () }

-- | Instruction used internally in @unregisterOutboundConnectionImpl@, e.g. in
-- the implementation of one of the two @DemotedToCold^{dataFlow}_{Local}@
-- transitions.
Expand Down Expand Up @@ -464,23 +483,25 @@ data DemoteToColdLocal peerAddr handlerTrace handle handleError version m

-- | Duplex connection was demoted, prune connections.
--
| PruneConnections (ConnectionId peerAddr)
(Map peerAddr (Async m ()))
-- Left case is for when pruning tries to prune
-- the connection which triggered pruning, in this
-- case we do not want to trace a new transition.
--
-- Right case is for when the connection which
-- triggered pruning isn't pruned. In this case
-- we do want to trace a new transition.
!(Either
(ConnectionState
peerAddr handle
handleError version m)
(Transition (ConnectionState
peerAddr handle
handleError version m))
)
| PruneConnections (PruneAction m)
-- ^ prune action

!(Either
(ConnectionState
peerAddr handle
handleError version m)
(Transition (ConnectionState
peerAddr handle
handleError version m))
)
-- ^ Left case is for when pruning tries to prune
-- the connection which triggered pruning, in this
-- case we do not want to trace a new transition.
--
-- Right case is for when the connection which
-- triggered pruning isn't pruned. In this case
-- we do want to trace a new transition.


-- | Demote error.
| DemoteToColdLocalError (ConnectionManagerTrace peerAddr handlerTrace)
Expand Down Expand Up @@ -857,6 +878,62 @@ withConnectionManager ConnectionManagerArguments {
traverse_ (traceWith trTracer . TransitionTrace peerAddr) trs
traceCounters stateVar

-- Pruning is done in two stages:
-- * an STM transaction which selects which connections to prune, and sets
-- their state to 'TerminatedState';
-- * an io action which logs and cancells all the connection handler
-- threads.
mkPruneAction :: peerAddr
-> Int
-- ^ number of connections to prune
-> ConnectionManagerState peerAddr handle handleError version m
-> ConnectionState peerAddr handle handleError version m
-- ^ next connection state, if it will not be pruned.
-> StrictTVar m (ConnectionState peerAddr handle handleError version m)
-> Async m ()
-> STM m (Bool, PruneAction m)
-- ^ return if the connection was choose to be prunned and the
-- 'PruneAction'
mkPruneAction peerAddr numberToPrune state connState' connVar connThread = do
(choiceMap' :: Map peerAddr ( ConnectionType
, Async m ()
, StrictTVar m
(ConnectionState
peerAddr
handle handleError
version m)
))
<- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } ->
(\cs -> do
-- this expression returns @Maybe (connType, connThread)@;
-- 'traverseMaybeWithKey' collects all 'Just' cases.
guard (isInboundConn cs)
(,,connVar') <$> getConnType cs
<*> getConnThread cs)
<$> readTVar connVar'
let choiceMap =
case getConnType connState' of
Nothing -> assert False choiceMap'
Just a -> Map.insert peerAddr (a, connThread, connVar)
choiceMap'

pruneSet <-
cmPrunePolicy
((\(a,_,_) -> a) <$> choiceMap)
numberToPrune

let pruneMap = choiceMap `Map.restrictKeys` pruneSet
forM_ pruneMap $ \(_, _, connVar') ->
writeTVar connVar' (TerminatedState Nothing)

return ( peerAddr `Set.member` pruneSet
, PruneAction $ do
traceWith tracer (TrPruneConnections (Map.keysSet pruneMap)
numberToPrune
(Map.keysSet choiceMap))
forM_ pruneMap $ \(_, connThread', _) -> cancel connThread'
)

includeInboundConnectionImpl
:: HasCallStack
=> FreshIdSupply m
Expand Down Expand Up @@ -1842,36 +1919,19 @@ withConnectionManager ConnectionManagerArguments {

-- use 'numberOfConns + 1' because we want to know if we
-- actually let this connection evolve if we need to make
-- room for them by pruning.
-- room for them by pruning. This is because
-- 'countIncomingConnections' does not count 'OutboundDupState'
-- as an inbound connection, but does so for 'InboundIdleState'.
let numberToPrune =
numberOfConns + 1
- fromIntegral
(acceptedConnectionsHardLimit cmConnectionsLimits)
if numberToPrune > 0
then do
-- traverse the state and get only the connection which
-- have 'ConnectionType' and are running (have a thread).
-- This excludes connections in 'ReservedOutboundState',
-- 'TerminatingState' and 'TerminatedState'.
(choiseMap :: Map peerAddr (ConnectionType, Async m ()))
<- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } ->
(\cs -> -- this expression returns @Maybe (connType, connThread)@;
-- 'traverseMaybeWithKey' collects all 'Just' cases.
(,) <$> getConnType cs
<*> getConnThread cs)
<$> readTVar connVar'

pruneSet <-
cmPrunePolicy
(fst <$> choiseMap)
numberToPrune

when (remoteAddress connId `Set.notMember` pruneSet)
$ writeTVar connVar connState'
(_, prune)
<- mkPruneAction peerAddr numberToPrune state connState' connVar connThread
return
( PruneConnections connId
(snd <$> choiseMap `Map.restrictKeys` pruneSet)
(Left connState)
( PruneConnections prune (Left connState)
, Nothing
)

Expand Down Expand Up @@ -1925,61 +1985,16 @@ withConnectionManager ConnectionManagerArguments {
let connState' = InboundState connId connThread handle Duplex
tr = mkTransition connState connState'

numberOfConns <- countIncomingConnections state
let numberToPrune =
numberOfConns
- fromIntegral
(acceptedConnectionsHardLimit cmConnectionsLimits)

if numberToPrune > 0
then do
-- traverse the state and get only the connection which
-- have 'ConnectionType' and are running (have a thread).
-- This excludes connections in 'ReservedOutboundState',
-- 'TerminatingState' and 'TerminatedState'.
(choiseMap :: Map peerAddr (ConnectionType, Async m ()))
<- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } ->
(\cs -> -- this expression returns @Maybe (connType, connThread)@;
-- 'traverseMaybeWithKey' collects all 'Just' cases.
(,) <$> getConnType cs
<*> getConnThread cs)
<$> readTVar connVar'

pruneSet <-
cmPrunePolicy
(fst <$> choiseMap)
numberToPrune

-- If this connection is in the to-prune set we do not let it
-- evolve to a new state. Otherwise we do.
if Set.member peerAddr pruneSet
then
return
( PruneConnections connId
(snd <$> choiseMap `Map.restrictKeys` pruneSet)
(Left connState)
, Nothing
)
else do
writeTVar connVar connState'
return
( PruneConnections connId
(snd <$> choiseMap `Map.restrictKeys` pruneSet)
(Right tr)
, Nothing
)

else do
-- @
-- DemotedToCold^{Duplex}_{Local} : DuplexState
-- → InboundState Duplex
-- @
-- does not require to perform any additional io action (we
-- already updated 'connVar').
writeTVar connVar connState'
return ( DemoteToColdLocalNoop (Just tr) st
, Nothing
)
-- @
-- DemotedToCold^{Duplex}_{Local} : DuplexState
-- → InboundState Duplex
-- @
-- does not require to perform any additional io action (we
-- already updated 'connVar').
writeTVar connVar connState'
return ( DemoteToColdLocalNoop (Just tr) st
, Nothing
)

TerminatingState _connId _connThread _handleError ->
return (DemoteToColdLocalNoop Nothing st
Expand Down Expand Up @@ -2032,13 +2047,10 @@ withConnectionManager ConnectionManagerArguments {
Left connState ->
return (UnsupportedState (abstractState $ Known connState))

PruneConnections _connId pruneMap eTr -> do
PruneConnections prune eTr -> do
traverse_ (traceWith trTracer . TransitionTrace peerAddr) eTr
runPruneAction prune
traceCounters stateVar
traceWith tracer (TrPruneConnections (Map.keys pruneMap))
-- previous comment applies here as well.
traverse_ cancel pruneMap

return (OperationSuccess (abstractState (either Known fromState eTr)))

DemoteToColdLocalError trace st -> do
Expand Down Expand Up @@ -2127,31 +2139,15 @@ withConnectionManager ConnectionManagerArguments {
-- Are we above the hard limit?
if numberToPrune > 0
then do
-- traverse the state and get only the connection which
-- have 'ConnectionType' and are running (have a thread).
-- This excludes connections in 'ReservedOutboundState',
-- 'TerminatingState' and 'TerminatedState'.
(choiseMap :: Map peerAddr (ConnectionType, Async m ()))
<- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } ->
(\cs -> -- this expression returns @Maybe (connType, connThread)@;
-- 'traverseMaybeWithKey' collects all 'Just' cases.
(,) <$> getConnType cs
<*> getConnThread cs)
<$> readTVar connVar'

pruneSet <-
cmPrunePolicy
(fst <$> choiseMap)
numberToPrune

when (remoteAddress connId `Set.notMember` pruneSet)
(pruneSelf, prune)
<- mkPruneAction peerAddr numberToPrune state connState' connVar connThread

when (not pruneSelf)
$ writeTVar connVar connState'

return
( OperationSuccess tr
, Just ( snd <$> choiseMap `Map.restrictKeys` pruneSet
, Nothing
)

, Just prune
, Nothing
)

Expand Down Expand Up @@ -2182,30 +2178,14 @@ withConnectionManager ConnectionManagerArguments {
-- Are we above the hard limit?
if numberToPrune > 0
then do
-- traverse the state and get only the connection which
-- have 'ConnectionType' and are running (have a thread).
-- This excludes connections in 'ReservedOutboundState',
-- 'TerminatingState' and 'TerminatedState'.
(choiseMap :: Map peerAddr (ConnectionType, Async m ()))
<- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } ->
(\cs -> -- this expression returns @Maybe (connType, connThread)@;
-- 'traverseMaybeWithKey' collects all 'Just' cases.
(,) <$> getConnType cs
<*> getConnThread cs)
<$> readTVar connVar'

pruneSet <-
cmPrunePolicy
(fst <$> choiseMap)
numberToPrune

when (remoteAddress connId `Set.notMember` pruneSet)
$ writeTVar connVar connState'
(pruneSelf, prune)
<- mkPruneAction peerAddr numberToPrune state connState' connVar connThread
when (not pruneSelf)
$ writeTVar connVar connState'

return
( OperationSuccess tr
, Just ( snd <$> choiseMap `Map.restrictKeys` pruneSet
, Nothing
)
( OperationSuccess (mkTransition connState (TerminatedState Nothing))
, Just prune
, Nothing
)

Expand Down Expand Up @@ -2268,17 +2248,11 @@ withConnectionManager ConnectionManagerArguments {
traceWith trTracer (TransitionTrace peerAddr tr)
traceCounters stateVar

(OperationSuccess _, Just (pruneMap, mbTr)) -> do
traceWith tracer (TrPruneConnections (Map.keys pruneMap))
traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTr
(OperationSuccess tr, Just prune) -> do
traceWith trTracer (TransitionTrace peerAddr tr)
runPruneAction prune
traceCounters stateVar

-- We relay on the `finally` handler of connection thread to:
--
-- - close the socket,
-- - set the state to 'TerminatedState'
traverse_ cancel pruneMap

_ -> return ()
return (abstractState . fromState <$> result)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,9 @@ data ConnectionManagerTrace peerAddr handlerTrace
| TrConnectionFailure !(ConnectionId peerAddr)
| TrConnectionNotFound !Provenance !peerAddr
| TrForbiddenOperation !peerAddr !AbstractState
| TrPruneConnections ![peerAddr]
| TrPruneConnections !(Set peerAddr) -- ^ prunning set
!Int -- ^ number connections that must be prunned
!(Set peerAddr) -- ^ choice set
| TrConnectionCleanup !(ConnectionId peerAddr)
| TrConnectionTimeWait !(ConnectionId peerAddr)
| TrConnectionTimeWaitDone !(ConnectionId peerAddr)
Expand Down
Loading