Skip to content

Commit

Permalink
trace-forward: remove potentially leaky continuation passing - datapo…
Browse files Browse the repository at this point in the history
…ints
  • Loading branch information
mgmeier authored and Jimbo4350 committed Sep 11, 2024
1 parent d6cfee7 commit db42685
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 41 deletions.
34 changes: 18 additions & 16 deletions trace-forward/src/Trace/Forward/Protocol/DataPoint/Forwarder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ data DataPointForwarder m a = DataPointForwarder
{ -- | The acceptor sent us a request for new 'DataPoint's.
recvMsgDataPointsRequest
:: [DataPointName]
-> m (DataPointValues, DataPointForwarder m a)
-> m DataPointValues

-- | The acceptor terminated. Here we have a pure return value, but we
-- could have done another action in 'm' if we wanted to.
Expand All @@ -30,19 +30,21 @@ dataPointForwarderPeer
:: Monad m
=> DataPointForwarder m a
-> Peer DataPointForward 'AsServer 'StIdle m a
dataPointForwarderPeer DataPointForwarder{recvMsgDataPointsRequest, recvMsgDone} =
-- In the 'StIdle' state the forwarder is awaiting a request message
-- from the acceptor.
Await (ClientAgency TokIdle) $ \case
-- The acceptor sent us a request for new 'DataPoint's, so now we're
-- in the 'StBusy' state which means it's the forwarder's turn to send
-- a reply.
MsgDataPointsRequest request -> Effect $ do
(reply, next) <- recvMsgDataPointsRequest request
return $ Yield (ServerAgency TokBusy)
(MsgDataPointsReply reply)
(dataPointForwarderPeer next)
dataPointForwarderPeer DataPointForwarder{recvMsgDataPointsRequest, recvMsgDone} = go
where
go =
-- In the 'StIdle' state the forwarder is awaiting a request message
-- from the acceptor.
Await (ClientAgency TokIdle) $ \case
-- The acceptor sent us a request for new 'DataPoint's, so now we're
-- in the 'StBusy' state which means it's the forwarder's turn to send
-- a reply.
MsgDataPointsRequest request -> Effect $ do
reply <- recvMsgDataPointsRequest request
return $ Yield (ServerAgency TokBusy)
(MsgDataPointsReply reply)
go

-- The acceptor sent the done transition, so we're in the 'StDone' state
-- so all we can do is stop using 'done', with a return value.
MsgDone -> Effect $ Done TokDone <$> recvMsgDone
-- The acceptor sent the done transition, so we're in the 'StDone' state
-- so all we can do is stop using 'done', with a return value.
MsgDone -> Effect $ Done TokDone <$> recvMsgDone
3 changes: 1 addition & 2 deletions trace-forward/src/Trace/Forward/Utils/DataPoint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ readFromStore dpStore =
DataPointForwarder
{ recvMsgDataPointsRequest = \dpNames -> do
store <- readTVarIO dpStore
let replyList = map (lookupDataPoint store) dpNames
return (replyList, readFromStore dpStore)
return $ map (lookupDataPoint store) dpNames
, recvMsgDone = return ()
}
where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ direct :: Monad m
direct DataPointForwarder { recvMsgDone }
(SendMsgDone mdone) =
(,) <$> recvMsgDone <*> mdone
direct DataPointForwarder { recvMsgDataPointsRequest }
direct server@DataPointForwarder { recvMsgDataPointsRequest }
(SendMsgDataPointsRequest (dpNames :: [DataPointName]) mclient) = do
(reply, server) <- recvMsgDataPointsRequest dpNames
reply <- recvMsgDataPointsRequest dpNames
client <- mclient reply
direct server client
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ module Test.Trace.Forward.Protocol.DataPoint.Examples
, dataPointForwarderCount
) where

import Control.Concurrent.Class.MonadSTM.TVar
import Control.Monad.Class.MonadSTM

import Trace.Forward.Protocol.DataPoint.Acceptor
import Trace.Forward.Protocol.DataPoint.Forwarder
import Trace.Forward.Protocol.DataPoint.Type
Expand All @@ -30,18 +33,17 @@ dataPointAcceptorApply f = go
$ \(_reply :: DataPointValues) -> return $ go (f acc) (pred n)

-- | A server which counts number received of 'MsgDataPointsRequest'.
--
dataPointForwarderCount
:: forall m. Monad m
=> DataPointForwarder m Int
dataPointForwarderCount = go 0
where
go n =
:: MonadSTM m
=> m (DataPointForwarder m Int)
dataPointForwarderCount = do
n <- newTVarIO 0
return $
DataPointForwarder
{ recvMsgDone = return n
{ recvMsgDone = readTVarIO n
, recvMsgDataPointsRequest =
\(dpNames :: [DataPointName]) ->
\(dpNames :: [DataPointName]) -> do
atomically $ modifyTVar' n succ
return ( zip dpNames (repeat Nothing)
, go (succ n)
)
}
40 changes: 28 additions & 12 deletions trace-forward/test/Test/Trace/Forward/Protocol/DataPoint/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Ouroboros.Network.Driver.Simple (runConnectedPeers)
import qualified Codec.Serialise as CBOR
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.IOSim (runSimOrThrow)
import Control.Monad.ST (runST)
Expand Down Expand Up @@ -77,23 +78,38 @@ prop_direct_DataPointForward
-> NonNegative Int
-> Property
prop_direct_DataPointForward f (NonNegative n) =
runSimOrThrow
(direct
dataPointForwarderCount
(dataPointAcceptorApply f 0 n))
===
(n, foldr ($) 0 (replicate n f))
runSimOrThrow (prop_direct f n)

prop_direct
:: MonadSTM m
=> (Int -> Int)
-> Int
-> m Property
prop_direct f n = do
fwcount <- dataPointForwarderCount
result <- direct fwcount (dataPointAcceptorApply f 0 n)
return $ result === (n, foldr ($) 0 (replicate n f))

prop_connect_DataPointForward
:: (Int -> Int)
-> NonNegative Int
-> Bool
prop_connect_DataPointForward f (NonNegative n) =
case runSimOrThrow
(connect
(dataPointForwarderPeer dataPointForwarderCount)
(dataPointAcceptorPeer $ dataPointAcceptorApply f 0 n)) of
(s, c, TerminalStates TokDone TokDone) -> (s, c) == (n, foldr ($) 0 (replicate n f))
runSimOrThrow (prop_connect f n)

prop_connect
:: ( MonadST m
, MonadAsync m
)
=> (Int -> Int)
-> Int
-> m Bool
prop_connect f n = do
forwarder <- dataPointForwarderPeer <$> dataPointForwarderCount
result <- connect forwarder (dataPointAcceptorPeer $ dataPointAcceptorApply f 0 n)
case result of
(s, c, TerminalStates TokDone TokDone) ->
pure $ (s, c) == (n, foldr ($) 0 (replicate n f))

prop_channel
:: ( MonadST m
Expand All @@ -104,14 +120,14 @@ prop_channel
-> Int
-> m Property
prop_channel f n = do
forwarder <- dataPointForwarderPeer <$> dataPointForwarderCount
(s, c) <- runConnectedPeers createConnectedChannels
nullTracer
(codecDataPointForward CBOR.encode CBOR.decode
CBOR.encode CBOR.decode)
forwarder acceptor
return $ (s, c) === (n, foldr ($) 0 (replicate n f))
where
forwarder = dataPointForwarderPeer dataPointForwarderCount
acceptor = dataPointAcceptorPeer $ dataPointAcceptorApply f 0 n

prop_channel_ST_DataPointForward
Expand Down

0 comments on commit db42685

Please sign in to comment.