diff --git a/Network/Socket/ByteString.hsc b/Network/Socket/ByteString.hsc index 2b2e0645..8469baaf 100644 --- a/Network/Socket/ByteString.hsc +++ b/Network/Socket/ByteString.hsc @@ -1,5 +1,6 @@ {-# OPTIONS_GHC -fno-warn-warnings-deprecations #-} {-# LANGUAGE CPP, ForeignFunctionInterface #-} +{-# LANGUAGE OverloadedStrings #-} #include "HsNet.h" @@ -41,6 +42,7 @@ module Network.Socket.ByteString , recvFrom ) where +import Control.Concurrent (threadWaitWrite) import Control.Exception as E (catch, throwIO) import Control.Monad (when) import Data.ByteString (ByteString) @@ -92,9 +94,11 @@ send sock xs = unsafeUseAsCStringLen xs $ \(str, len) -> sendAll :: Socket -- ^ Connected socket -> ByteString -- ^ Data to send -> IO () +sendAll _ "" = return () sendAll sock bs = do - sent <- send sock bs - when (sent < B.length bs) $ sendAll sock (B.drop sent bs) + sent <- send sock bs + when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock + when (sent >= 0) $ sendAll sock $ B.drop sent bs -- | Send data to the socket. The recipient can be specified -- explicitly, so the socket need not be in a connected state. @@ -121,9 +125,11 @@ sendAllTo :: Socket -- ^ Socket -> ByteString -- ^ Data to send -> SockAddr -- ^ Recipient address -> IO () +sendAllTo _ "" _ = return () sendAllTo sock xs addr = do sent <- sendTo sock xs addr - when (sent < B.length xs) $ sendAllTo sock (B.drop sent xs) addr + when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock + when (sent >= 0) $ sendAllTo sock (B.drop sent xs) addr -- ---------------------------------------------------------------------------- -- ** Vectored I/O @@ -159,9 +165,11 @@ sendMany :: Socket -- ^ Connected socket -> [ByteString] -- ^ Data to send -> IO () #if !defined(mingw32_HOST_OS) +sendMany _ [] = return () sendMany sock@(MkSocket fd _ _ _ _) cs = do sent <- sendManyInner - when (sent < totalLength cs) $ sendMany sock (remainingChunks sent cs) + when (sent == 0) $ threadWaitWrite $ fromIntegral fd + when (sent >= 0) $ sendMany sock (remainingChunks sent cs) where sendManyInner = liftM fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) -> @@ -185,9 +193,11 @@ sendManyTo :: Socket -- ^ Socket -> SockAddr -- ^ Recipient address -> IO () #if !defined(mingw32_HOST_OS) +sendManyTo _ [] _ = return () sendManyTo sock@(MkSocket fd _ _ _ _) cs addr = do sent <- liftM fromIntegral sendManyToInner - when (sent < totalLength cs) $ sendManyTo sock (remainingChunks sent cs) addr + when (sent == 0) $ threadWaitWrite $ fromIntegral fd + when (sent >= 0) $ sendManyTo sock (remainingChunks sent cs) addr where sendManyToInner = withSockAddr addr $ \addrPtr addrSize -> @@ -258,10 +268,6 @@ remainingChunks i (x:xs) where len = B.length x --- | @totalLength cs@ is the sum of the lengths of the chunks in the list @cs@. -totalLength :: [ByteString] -> Int -totalLength = sum . map B.length - -- | @withIOVec cs f@ executes the computation @f@, passing as argument a pair -- consisting of a pointer to a temporarily allocated array of pointers to -- IOVec made from @cs@ and the number of pointers (@length cs@). diff --git a/Network/Socket/ByteString/Lazy/Posix.hs b/Network/Socket/ByteString/Lazy/Posix.hs index 5d78b97a..a95b31bd 100644 --- a/Network/Socket/ByteString/Lazy/Posix.hs +++ b/Network/Socket/ByteString/Lazy/Posix.hs @@ -7,8 +7,8 @@ module Network.Socket.ByteString.Lazy.Posix , sendAll ) where -import Control.Monad (liftM) -import Control.Monad (unless) +import Control.Concurrent (threadWaitWrite) +import Control.Monad (liftM, when) import qualified Data.ByteString.Lazy as L import Data.ByteString.Lazy.Internal (ByteString(..)) import Data.ByteString.Unsafe (unsafeUseAsCStringLen) @@ -21,6 +21,7 @@ import Network.Socket (Socket(..)) import Network.Socket.ByteString.IOVec (IOVec(IOVec)) import Network.Socket.ByteString.Internal (c_writev) import Network.Socket.Internal +import Network.Socket.Types (sockFd) -- ----------------------------------------------------------------------------- -- Sending @@ -53,5 +54,5 @@ sendAll :: Socket -- ^ Connected socket -> IO () sendAll sock bs = do sent <- send sock bs - let bs' = L.drop sent bs - unless (L.null bs') $ sendAll sock bs' + when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock + when (sent >= 0) $ sendAll sock $ L.drop sent bs diff --git a/Network/Socket/ByteString/Lazy/Windows.hs b/Network/Socket/ByteString/Lazy/Windows.hs index ef2e3be0..968d4675 100644 --- a/Network/Socket/ByteString/Lazy/Windows.hs +++ b/Network/Socket/ByteString/Lazy/Windows.hs @@ -7,12 +7,14 @@ module Network.Socket.ByteString.Lazy.Windows ) where import Control.Applicative ((<$>)) -import Control.Monad (unless) +import Control.Concurrent (threadWaitWrite) +import Control.Monad (when) import qualified Data.ByteString as S import qualified Data.ByteString.Lazy as L import Data.Int (Int64) import Network.Socket (Socket(..)) +import Network.Socket.Types (sockFd) import qualified Network.Socket.ByteString as Socket -- ----------------------------------------------------------------------------- @@ -32,5 +34,5 @@ sendAll :: Socket -- ^ Connected socket -> IO () sendAll sock bs = do sent <- send sock bs - let bs' = L.drop sent bs - unless (L.null bs') $ sendAll sock bs' + when (sent == 0) $ threadWaitWrite $ fromIntegral $ sockFd sock + when (sent >= 0) $ sendAll sock $ L.drop sent bs