diff --git a/Network/HTTP/ReverseProxy.hs b/Network/HTTP/ReverseProxy.hs index c68021e..1ba477b 100644 --- a/Network/HTTP/ReverseProxy.hs +++ b/Network/HTTP/ReverseProxy.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE RankNTypes #-} module Network.HTTP.ReverseProxy ( -- * Types ProxyDest (..) @@ -40,7 +41,7 @@ module Network.HTTP.ReverseProxy import Blaze.ByteString.Builder (Builder, fromByteString, toLazyByteString) import Control.Applicative ((<$>), (<|>)) -import Control.Monad (unless) +import Control.Monad (unless, void) import Data.ByteString (ByteString) import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as S8 @@ -49,6 +50,7 @@ import qualified Data.CaseInsensitive as CI import Data.Conduit import qualified Data.Conduit.List as CL import qualified Data.Conduit.Network as DCN +import qualified Data.Conduit.Network.Unix as DCNU import Data.Functor.Identity (Identity (..)) import Data.IORef import Data.List.NonEmpty (NonEmpty (..)) @@ -57,11 +59,11 @@ import Data.Maybe (fromMaybe, listToMaybe) import Data.Monoid (mappend, mconcat, (<>)) import Data.Set (Set) import qualified Data.Set as Set -import Data.Streaming.Network (AppData, readLens) -import qualified Data.Text.Lazy as TL -import qualified Data.Text.Lazy.Encoding as TLE +import Data.Streaming.Network (AppData, HasReadWrite, readLens) import qualified Data.Text as T import qualified Data.Text.Encoding as TE +import qualified Data.Text.Lazy as TL +import qualified Data.Text.Lazy.Encoding as TLE import Data.Word8 (isSpace, _colon, _cr) import GHC.Generics (Generic) import Network.HTTP.Client (BodyReader, brRead) @@ -75,7 +77,12 @@ import UnliftIO (MonadIO, liftIO, MonadUnliftIO, data ProxyDest = ProxyDest { pdHost :: !ByteString , pdPort :: !Int - } deriving (Read, Show, Eq, Ord, Generic) + } + | ProxyDestUnix { + pdSocketPath :: FilePath + } + deriving (Read, Show, Eq, Ord, Generic) + -- | Set up a reverse proxy server, which will have a minimal overhead. -- @@ -306,11 +313,11 @@ renderHeaders req headers <> fromByteString ": " <> fromByteString y -tryWebSockets :: WaiProxySettings -> ByteString -> Int -> WAI.Request -> (WAI.Response -> IO b) -> IO b -> IO b -tryWebSockets wps host port req sendResponse fallback +tryWebSockets :: (HasReadWrite ad) => WaiProxySettings -> (forall a. (ad -> IO a) -> IO a) -> WAI.Request -> (WAI.Response -> IO b) -> IO b -> IO b +tryWebSockets wps runConduitClient req sendResponse fallback | wpsUpgradeToRaw wps req = sendResponse $ flip WAI.responseRaw backup $ \fromClientBody toClient -> - DCN.runTCPClient settings $ \server -> + runConduitClient $ \server -> let toServer = DCN.appSink server fromServer = DCN.appSource server fromClient = do @@ -330,7 +337,6 @@ tryWebSockets wps host port req sendResponse fallback where backup = WAI.responseLBS HT.status500 [("Content-Type", "text/plain")] "http-reverse-proxy detected WebSockets request, but server does not support responseRaw" - settings = DCN.clientSettings port host strippedHeaders :: Set HT.HeaderName strippedHeaders = Set.fromList @@ -377,52 +383,59 @@ waiProxyToSettings getDest wps' manager req0 sendResponse = do timeout us f >>= \case Just res -> return res Nothing -> sendResponse $ WAI.responseLBS HT.status500 [] "timeBound" - case edest of - Left app -> maybe id timeBound (lpsTimeBound lps) $ app req0 sendResponse - Right (ProxyDest host port, req, secure) -> tryWebSockets wps host port req sendResponse $ do - scb <- semiCachedBody (WAI.requestBody req) - let body = - case WAI.requestBodyLength req of - WAI.KnownLength i -> HC.RequestBodyStream (fromIntegral i) scb - WAI.ChunkedBody -> HC.RequestBodyStreamChunked scb + fallback req secure maybeHostPort = do + scb <- semiCachedBody (WAI.requestBody req) + let body = + case WAI.requestBodyLength req of + WAI.KnownLength i -> HC.RequestBodyStream (fromIntegral i) scb + WAI.ChunkedBody -> HC.RequestBodyStreamChunked scb - let req' = + let baseReq = #if MIN_VERSION_http_client(0, 5, 0) - HC.defaultRequest - { HC.checkResponse = \_ _ -> return () - , HC.responseTimeout = maybe HC.responseTimeoutNone HC.responseTimeoutMicro $ lpsTimeBound lps + HC.defaultRequest + { HC.checkResponse = \_ _ -> return () + , HC.responseTimeout = maybe HC.responseTimeoutNone HC.responseTimeoutMicro $ lpsTimeBound lps + } #else - def - { HC.checkStatus = \_ _ _ -> Nothing - , HC.responseTimeout = lpsTimeBound lps + def + { HC.checkStatus = \_ _ _ -> Nothing + , HC.responseTimeout = lpsTimeBound lps + } #endif - , HC.method = WAI.requestMethod req - , HC.secure = secure - , HC.host = host - , HC.port = port - , HC.path = WAI.rawPathInfo req - , HC.queryString = WAI.rawQueryString req - , HC.requestHeaders = fixReqHeaders wps req - , HC.requestBody = body - , HC.redirectCount = 0 - } - bracket - (try $ HC.responseOpen req' manager) - (either (const $ return ()) HC.responseClose) - $ \case - Left e -> wpsOnExc wps e req sendResponse - Right res -> do - let conduit = fromMaybe - (awaitForever (\bs -> yield (Chunk $ fromByteString bs) >> yield Flush)) - (wpsProcessBody wps req $ const () <$> res) - src = bodyReaderSource $ HC.responseBody res - sendResponse $ WAI.responseStream - (HC.responseStatus res) - (filter (\(key, _) -> not $ key `Set.member` strippedHeaders) $ HC.responseHeaders res) - (\sendChunk flush -> runConduit $ src .| conduit .| CL.mapM_ (\mb -> - case mb of - Flush -> flush - Chunk b -> sendChunk b)) + + let req' = baseReq { + HC.method = WAI.requestMethod req + , HC.secure = secure + , HC.host = maybe (HC.host baseReq) fst maybeHostPort + , HC.port = maybe (HC.port baseReq) snd maybeHostPort + , HC.path = WAI.rawPathInfo req + , HC.queryString = WAI.rawQueryString req + , HC.requestHeaders = fixReqHeaders wps req + , HC.requestBody = body + , HC.redirectCount = 0 + } + bracket + (try $ HC.responseOpen req' manager) + (either (const $ return ()) HC.responseClose) + $ \case + Left e -> wpsOnExc wps e req sendResponse + Right res -> do + let conduit = fromMaybe + (awaitForever (\bs -> yield (Chunk $ fromByteString bs) >> yield Flush)) + (wpsProcessBody wps req $ const () <$> res) + src = bodyReaderSource $ HC.responseBody res + sendResponse $ WAI.responseStream + (HC.responseStatus res) + (filter (\(key, _) -> not $ key `Set.member` strippedHeaders) $ HC.responseHeaders res) + (\sendChunk flush -> runConduit $ src .| conduit .| CL.mapM_ (\mb -> + case mb of + Flush -> flush + Chunk b -> sendChunk b)) + + case edest of + Left app -> maybe id timeBound (lpsTimeBound lps) $ app req0 sendResponse + Right (ProxyDest host port, req, secure) -> tryWebSockets wps (DCN.runTCPClient (DCN.clientSettings port host)) req sendResponse (fallback req secure (Just (host, port))) + Right (ProxyDestUnix socketPath, req, secure) -> tryWebSockets wps (DCNU.runUnixClient (DCNU.clientSettings socketPath)) req sendResponse (fallback req secure Nothing) -- | Introduce a minor level of caching to handle some basic -- retry cases inside http-client. But to avoid a DoS attack, diff --git a/stack.yaml b/stack.yaml index 8381b03..82823e4 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1 +1 @@ -resolver: lts-16.20 +resolver: lts-19.6