Skip to content

Commit

Permalink
Simplify session vault key handling
Browse files Browse the repository at this point in the history
Moved the session vault key from the ApplicationContext and RequestContext data structure to a global variable. This is the suggested way by the WAI developers.

See https://www.yesodweb.com/blog/2015/10/using-wais-vault
  • Loading branch information
mpscholten committed Mar 15, 2024
1 parent e73425a commit ffbb5c6
Show file tree
Hide file tree
Showing 14 changed files with 33 additions and 33 deletions.
2 changes: 0 additions & 2 deletions IHP/ApplicationContext.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ module IHP.ApplicationContext where

import IHP.Prelude
import Network.Wai.Session (Session)
import qualified Data.Vault.Lazy as Vault
import IHP.AutoRefresh.Types (AutoRefreshServer)
import IHP.FrameworkConfig (FrameworkConfig)
import IHP.PGListener (PGListener)

data ApplicationContext = ApplicationContext
{ modelContext :: !ModelContext
, session :: !(Vault.Key (Session IO ByteString ByteString))
, autoRefreshServer :: !(IORef AutoRefreshServer)
, frameworkConfig :: !FrameworkConfig
, pgListener :: PGListener
Expand Down
1 change: 0 additions & 1 deletion IHP/Controller/RequestContext.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ data RequestContext = RequestContext
{ request :: Request
, respond :: Respond
, requestBody :: RequestBody
, vault :: (Vault.Key (Session IO ByteString ByteString))
, frameworkConfig :: FrameworkConfig
}
11 changes: 9 additions & 2 deletions IHP/Controller/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module IHP.Controller.Session
, getSessionEither
, deleteSession
, getSessionAndClear
, sessionVaultKey
) where

import IHP.Prelude
Expand All @@ -36,6 +37,8 @@ import qualified Network.Wai as Wai
import qualified Data.Serialize as Serialize
import Data.Serialize (Serialize)
import Data.Serialize.Text ()
import qualified Network.Wai.Session
import System.IO.Unsafe (unsafePerformIO)

-- | Types of possible errors as a result of
-- requesting a value from the session storage
Expand Down Expand Up @@ -161,5 +164,9 @@ sessionVault = case vaultLookup of
Just session -> session
Nothing -> error "sessionInsert: The session vault is missing in the request"
where
RequestContext { request, vault } = ?context.requestContext
vaultLookup = Vault.lookup vault (Wai.vault request)
RequestContext { request } = ?context.requestContext
vaultLookup = Vault.lookup sessionVaultKey request.vault

sessionVaultKey :: Vault.Key (Network.Wai.Session.Session IO ByteString ByteString)
sessionVaultKey = unsafePerformIO Vault.newKey
{-# NOINLINE sessionVaultKey #-}
4 changes: 2 additions & 2 deletions IHP/ControllerSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ requestBodyJSON =

{-# INLINE createRequestContext #-}
createRequestContext :: ApplicationContext -> Request -> Respond -> IO RequestContext
createRequestContext ApplicationContext { session, frameworkConfig } request respond = do
createRequestContext ApplicationContext { frameworkConfig } request respond = do
let contentType = lookup hContentType (requestHeaders request)
requestBody <- case contentType of
"application/json" -> do
Expand All @@ -270,7 +270,7 @@ createRequestContext ApplicationContext { session, frameworkConfig } request res
(params, files) <- WaiParse.parseRequestBodyEx frameworkConfig.parseRequestBodyOptions WaiParse.lbsBackEnd request
pure RequestContext.FormBody { .. }

pure RequestContext.RequestContext { request, respond, requestBody, vault = session, frameworkConfig }
pure RequestContext.RequestContext { request, respond, requestBody, frameworkConfig }


-- | Returns a custom config parameter
Expand Down
2 changes: 1 addition & 1 deletion IHP/RouterSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ withPrefix prefix routes = string prefix >> choice (map (\r -> r <* endOfInput)

frontControllerToWAIApp :: forall app (autoRefreshApp :: Type). (?applicationContext :: ApplicationContext, FrontController app, WSApp autoRefreshApp, Typeable autoRefreshApp, InitControllerContext ()) => Middleware -> app -> Application -> Application
frontControllerToWAIApp middleware application notFoundAction request respond = do
let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, vault = ?applicationContext.session, frameworkConfig = ?applicationContext.frameworkConfig }
let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, frameworkConfig = ?applicationContext.frameworkConfig }

let ?context = requestContext

Expand Down
13 changes: 6 additions & 7 deletions IHP/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Network.Wai.Middleware.MethodOverridePost (methodOverridePost)
import Network.Wai.Session (withSession, Session)
import Network.Wai.Session.ClientSession (clientsessionStore)
import qualified Web.ClientSession as ClientSession
import IHP.Controller.Session (sessionVaultKey)
import qualified Data.Vault.Lazy as Vault
import IHP.ApplicationContext
import qualified IHP.ControllerSupport as ControllerSupport
Expand Down Expand Up @@ -48,14 +49,12 @@ run configBuilder = do

withInitalizers frameworkConfig modelContext do
withPGListener \pgListener -> do
sessionVault <- Vault.newKey

autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)

let ?modelContext = modelContext
let ?applicationContext = ApplicationContext { modelContext = ?modelContext, session = sessionVault, autoRefreshServer, frameworkConfig, pgListener }
let ?applicationContext = ApplicationContext { modelContext = ?modelContext, autoRefreshServer, frameworkConfig, pgListener }

sessionMiddleware <- initSessionMiddleware sessionVault frameworkConfig
sessionMiddleware <- initSessionMiddleware frameworkConfig
staticApp <- initStaticApp frameworkConfig
let corsMiddleware = initCorsMiddleware frameworkConfig
let requestLoggerMiddleware = frameworkConfig.requestLoggerMiddleware
Expand Down Expand Up @@ -108,8 +107,8 @@ initStaticApp frameworkConfig = do

pure (Static.staticApp appSettings)

initSessionMiddleware :: Vault.Key (Session IO ByteString ByteString) -> FrameworkConfig -> IO Middleware
initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do
initSessionMiddleware :: FrameworkConfig -> IO Middleware
initSessionMiddleware FrameworkConfig { sessionCookie } = do
let path = "Config/client_session_key.aes"

hasSessionSecretEnvVar <- EnvVar.hasEnvVar "IHP_SESSION_SECRET"
Expand All @@ -118,7 +117,7 @@ initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do
if hasSessionSecretEnvVar || not doesConfigDirectoryExist
then ClientSession.getKeyEnv "IHP_SESSION_SECRET"
else ClientSession.getKey path
let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVault
let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVaultKey
pure sessionMiddleware

initCorsMiddleware :: FrameworkConfig -> Middleware
Expand Down
17 changes: 7 additions & 10 deletions IHP/Test/Mocking.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import qualified Network.Wai.Session
import qualified Data.Serialize as Serialize
import qualified Control.Exception as Exception
import qualified IHP.PGListener as PGListener
import IHP.Controller.Session (sessionVaultKey)

type ContextParameters application = (?applicationContext :: ApplicationContext, ?context :: RequestContext, ?modelContext :: ModelContext, ?application :: application, InitControllerContext application, ?mocking :: MockContext application)

Expand All @@ -58,17 +59,15 @@ withIHPApp application configBuilder hspecAction = do
withTestDatabase \testDatabase -> do
modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections (testDatabase.url) logger

session <- Vault.newKey
pgListener <- PGListener.init modelContext
autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)
let sessionVault = Vault.insert session mempty Vault.empty
let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener }
let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty
let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener }

let requestContext = RequestContext
{ request = defaultRequest {vault = sessionVault}
, requestBody = FormBody [] []
, respond = const (pure ResponseReceived)
, vault = session
, frameworkConfig = frameworkConfig }

(hspecAction MockContext { .. })
Expand All @@ -81,17 +80,15 @@ mockContextNoDatabase application configBuilder = do
logger <- newLogger def { level = Warn } -- don't log queries
modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections databaseUrl logger

session <- Vault.newKey
let sessionVault = Vault.insert session mempty Vault.empty
let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty
pgListener <- PGListener.init modelContext
autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)
let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener }
let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener }

let requestContext = RequestContext
{ request = defaultRequest {vault = sessionVault}
, requestBody = FormBody [] []
, respond = \resp -> pure ResponseReceived
, vault = session
, frameworkConfig = frameworkConfig }

pure MockContext{..}
Expand Down Expand Up @@ -230,8 +227,8 @@ withUser user callback =

insertSession key value = pure ()

newVault = Vault.insert vaultKey newSession (Wai.vault request)
RequestContext { request, vault = vaultKey } = ?mocking.requestContext
newVault = Vault.insert sessionVaultKey newSession (Wai.vault request)
RequestContext { request } = ?mocking.requestContext

sessionValue = Serialize.encode (user.id)
sessionKey = cs (Session.sessionKey @user)
Expand Down
2 changes: 1 addition & 1 deletion Test/Controller/AccessDeniedSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app))

assertAccessDenied :: SResponse -> IO ()
Expand Down
2 changes: 1 addition & 1 deletion Test/Controller/CookieSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ createControllerContext = do
let
requestBody = FormBody { params = [], files = [] }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
let ?requestContext = requestContext
newControllerContext
2 changes: 1 addition & 1 deletion Test/Controller/NotFoundSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app))

assertNotFound :: SResponse -> IO ()
Expand Down
4 changes: 2 additions & 2 deletions Test/Controller/ParamSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,14 @@ createControllerContextWithParams params =
let
requestBody = FormBody { params, files = [] }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
in FrozenControllerContext { requestContext, customFields = TypeMap.empty }

createControllerContextWithJson params =
let
requestBody = JSONBody { jsonPayload = Just (json params), rawPayload = cs params }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
in FrozenControllerContext { requestContext, customFields = TypeMap.empty }

json :: Text -> Aeson.Value
Expand Down
2 changes: 1 addition & 1 deletion Test/View/CSSFrameworkSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -721,5 +721,5 @@ createControllerContextWithCSSFramework cssFramework = do
option cssFramework
let requestBody = FormBody { params = [], files = [] }
let request = Wai.defaultRequest
let requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = frameworkConfig }
let requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = frameworkConfig }
pure FrozenControllerContext { requestContext, customFields = TypeMap.empty }
2 changes: 1 addition & 1 deletion Test/View/FormSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ createControllerContext = do
frameworkConfig <- FrameworkConfig.buildFrameworkConfig (pure ())
let requestBody = FormBody { params = [], files = [] }
let request = Wai.defaultRequest
let requestContext = RequestContext { request, respond = undefined, requestBody, vault = undefined, frameworkConfig = frameworkConfig }
let requestContext = RequestContext { request, respond = undefined, requestBody, frameworkConfig = frameworkConfig }
pure FrozenControllerContext { requestContext, customFields = mempty }

data Project' = Project {id :: (Id' "projects"), title :: Text, meta :: MetaBag} deriving (Eq, Show)
Expand Down
2 changes: 1 addition & 1 deletion Test/ViewSupportSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound (\app -> app)))

tests :: Spec
Expand Down

0 comments on commit ffbb5c6

Please sign in to comment.