Skip to content

Commit

Permalink
fixing the case where sent == 0 (haskell#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazu-yamamoto committed May 25, 2018
1 parent 69b3c5c commit 545da23
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
24 changes: 15 additions & 9 deletions Network/Socket/ByteString.hsc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
{-# LANGUAGE CPP, ForeignFunctionInterface #-}
{-# LANGUAGE OverloadedStrings #-}

#include "HsNet.h"

Expand Down Expand Up @@ -41,6 +42,7 @@ module Network.Socket.ByteString
, recvFrom
) where

import Control.Concurrent (threadWaitWrite)
import Control.Exception as E (catch, throwIO)
import Control.Monad (when)
import Data.ByteString (ByteString)
Expand Down Expand Up @@ -92,9 +94,11 @@ send sock xs = unsafeUseAsCStringLen xs $ \(str, len) ->
sendAll :: Socket -- ^ Connected socket
-> ByteString -- ^ Data to send
-> IO ()
sendAll _ "" = return ()
sendAll sock bs = do
sent <- send sock bs
when (sent < B.length bs) $ sendAll sock (B.drop sent bs)
sent <- send sock bs
when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock
when (sent >= 0) $ sendAll sock $ B.drop sent bs

-- | Send data to the socket. The recipient can be specified
-- explicitly, so the socket need not be in a connected state.
Expand All @@ -121,9 +125,11 @@ sendAllTo :: Socket -- ^ Socket
-> ByteString -- ^ Data to send
-> SockAddr -- ^ Recipient address
-> IO ()
sendAllTo _ "" _ = return ()
sendAllTo sock xs addr = do
sent <- sendTo sock xs addr
when (sent < B.length xs) $ sendAllTo sock (B.drop sent xs) addr
when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock
when (sent >= 0) $ sendAllTo sock (B.drop sent xs) addr

-- ----------------------------------------------------------------------------
-- ** Vectored I/O
Expand Down Expand Up @@ -159,9 +165,11 @@ sendMany :: Socket -- ^ Connected socket
-> [ByteString] -- ^ Data to send
-> IO ()
#if !defined(mingw32_HOST_OS)
sendMany _ [] = return ()
sendMany sock@(MkSocket fd _ _ _ _) cs = do
sent <- sendManyInner
when (sent < totalLength cs) $ sendMany sock (remainingChunks sent cs)
when (sent == 0) $ threadWaitWrite $ fromIntegral fd
when (sent >= 0) $ sendMany sock (remainingChunks sent cs)
where
sendManyInner =
liftM fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) ->
Expand All @@ -185,9 +193,11 @@ sendManyTo :: Socket -- ^ Socket
-> SockAddr -- ^ Recipient address
-> IO ()
#if !defined(mingw32_HOST_OS)
sendManyTo _ [] _ = return ()
sendManyTo sock@(MkSocket fd _ _ _ _) cs addr = do
sent <- liftM fromIntegral sendManyToInner
when (sent < totalLength cs) $ sendManyTo sock (remainingChunks sent cs) addr
when (sent == 0) $ threadWaitWrite $ fromIntegral fd
when (sent >= 0) $ sendManyTo sock (remainingChunks sent cs) addr
where
sendManyToInner =
withSockAddr addr $ \addrPtr addrSize ->
Expand Down Expand Up @@ -258,10 +268,6 @@ remainingChunks i (x:xs)
where
len = B.length x

-- | @totalLength cs@ is the sum of the lengths of the chunks in the list @cs@.
totalLength :: [ByteString] -> Int
totalLength = sum . map B.length

-- | @withIOVec cs f@ executes the computation @f@, passing as argument a pair
-- consisting of a pointer to a temporarily allocated array of pointers to
-- IOVec made from @cs@ and the number of pointers (@length cs@).
Expand Down
9 changes: 5 additions & 4 deletions Network/Socket/ByteString/Lazy/Posix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ module Network.Socket.ByteString.Lazy.Posix
, sendAll
) where

import Control.Monad (liftM)
import Control.Monad (unless)
import Control.Concurrent (threadWaitWrite)
import Control.Monad (liftM, when)
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Lazy.Internal (ByteString(..))
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
Expand All @@ -21,6 +21,7 @@ import Network.Socket (Socket(..))
import Network.Socket.ByteString.IOVec (IOVec(IOVec))
import Network.Socket.ByteString.Internal (c_writev)
import Network.Socket.Internal
import Network.Socket.Types (sockFd)

-- -----------------------------------------------------------------------------
-- Sending
Expand Down Expand Up @@ -53,5 +54,5 @@ sendAll :: Socket -- ^ Connected socket
-> IO ()
sendAll sock bs = do
sent <- send sock bs
let bs' = L.drop sent bs
unless (L.null bs') $ sendAll sock bs'
when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock
when (sent >= 0) $ sendAll sock $ L.drop sent bs
8 changes: 5 additions & 3 deletions Network/Socket/ByteString/Lazy/Windows.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ module Network.Socket.ByteString.Lazy.Windows
) where

import Control.Applicative ((<$>))
import Control.Monad (unless)
import Control.Concurrent (threadWaitWrite)
import Control.Monad (when)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Int (Int64)

import Network.Socket (Socket(..))
import Network.Socket.Types (sockFd)
import qualified Network.Socket.ByteString as Socket

-- -----------------------------------------------------------------------------
Expand All @@ -32,5 +34,5 @@ sendAll :: Socket -- ^ Connected socket
-> IO ()
sendAll sock bs = do
sent <- send sock bs
let bs' = L.drop sent bs
unless (L.null bs') $ sendAll sock bs'
when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock
when (sent >= 0) $ sendAll sock $ L.drop sent bs

0 comments on commit 545da23

Please sign in to comment.