diff --git a/IHP/DataSync/Controller.hs b/IHP/DataSync/Controller.hs index b53c59117..869024f15 100644 --- a/IHP/DataSync/Controller.hs +++ b/IHP/DataSync/Controller.hs @@ -37,7 +37,7 @@ instance ( initialState = DataSyncController run = do - setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty } + setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty, asyncs = [] } ensureRLSEnabled <- makeCachedEnsureRLSEnabled installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers @@ -56,6 +56,8 @@ instance ( sendJSON DataSyncResult { result, requestId } handleMessage CreateDataSubscription { query, requestId } = do + ensureBelowSubscriptionsLimit + tableNameRLS <- ensureRLSEnabled (get #table query) subscriptionId <- UUID.nextRandom @@ -111,22 +113,23 @@ instance ( when isWatchingRecord do sendJSON DidDelete { subscriptionId, id } + let subscribe = PGListener.subscribeJSON (ChangeNotifications.channelName tableNameRLS) callback pgListener + let unsubscribe subscription = PGListener.unsubscribe subscription pgListener - channelSubscription <- pgListener - |> PGListener.subscribeJSON (ChangeNotifications.channelName tableNameRLS) callback + Exception.bracket subscribe unsubscribe \channelSubscription -> do + close <- MVar.newEmptyMVar + modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.insert subscriptionId close)) - modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.insert subscriptionId Subscription { id = subscriptionId, channelSubscription })) + sendJSON DidCreateDataSubscription { subscriptionId, requestId, result } - sendJSON DidCreateDataSubscription { subscriptionId, requestId, result } + MVar.takeMVar close handleMessage DeleteDataSubscription { requestId, subscriptionId } = do DataSyncReady { subscriptions } <- getState - let maybeSubscription :: Maybe Subscription = HashMap.lookup subscriptionId subscriptions + let (Just closeSignalMVar) = HashMap.lookup subscriptionId subscriptions -- Cancel table watcher - case maybeSubscription of - Just subscription -> pgListener |> PGListener.unsubscribe (get #channelSubscription subscription) - Nothing -> pure () + MVar.putMVar closeSignalMVar () modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.delete subscriptionId)) @@ -260,35 +263,49 @@ instance ( ensureBelowTransactionLimit transactionId <- UUID.nextRandom - - (connection, localPool) <- ?modelContext - |> get #connectionPool - |> Pool.takeResource - let transaction = DataSyncTransaction - { id = transactionId - , connection - , releaseConnection = Pool.putResource localPool connection - } - let globalModelContext = ?modelContext - let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" () + let takeConnection = ?modelContext + |> get #connectionPool + |> Pool.takeResource + + let releaseConnection (connection, localPool) = do + PG.execute connection "ROLLBACK" () -- Make sure there's no pending transaction in case something went wrong + Pool.putResource localPool connection + + Exception.bracket takeConnection releaseConnection \(connection, localPool) -> do + transactionSignal <- MVar.newEmptyMVar + + let globalModelContext = ?modelContext + let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" () - modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction)) + let transaction = DataSyncTransaction + { id = transactionId + , connection + , close = transactionSignal + } - sendJSON DidStartTransaction { requestId, transactionId } + modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction)) + + sendJSON DidStartTransaction { requestId, transactionId } + + MVar.takeMVar transactionSignal + + modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId)) handleMessage RollbackTransaction { requestId, id } = do - sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" () + DataSyncTransaction { id, close } <- findTransactionById id - closeTransaction id + sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" () + MVar.putMVar close () sendJSON DidRollbackTransaction { requestId, transactionId = id } handleMessage CommitTransaction { requestId, id } = do - sqlExecWithRLSAndTransactionId (Just id) "COMMIT" () + DataSyncTransaction { id, close } <- findTransactionById id - closeTransaction id + sqlExecWithRLSAndTransactionId (Just id) "COMMIT" () + MVar.putMVar close () sendJSON DidCommitTransaction { requestId, transactionId = id } @@ -301,22 +318,24 @@ instance ( Right decodedMessage -> do let requestId = get #requestId decodedMessage - -- Handle the messages in an async way - -- This increases throughput as multiple queries can be fetched - -- in parallel - async do - result <- Exception.try (handleMessage decodedMessage) - - case result of - Left (e :: Exception.SomeException) -> do - let errorMessage = case fromException e of - Just (enhancedSqlError :: EnhancedSqlError) -> cs (get #sqlErrorMsg (get #sqlError enhancedSqlError)) - Nothing -> cs (displayException e) - Log.error (tshow e) - sendJSON DataSyncError { requestId, errorMessage } - Right result -> pure () - - pure () + Exception.mask \restore -> do + -- Handle the messages in an async way + -- This increases throughput as multiple queries can be fetched + -- in parallel + handlerProcess <- async $ restore do + result <- Exception.try (handleMessage decodedMessage) + + case result of + Left (e :: Exception.SomeException) -> do + let errorMessage = case fromException e of + Just (enhancedSqlError :: EnhancedSqlError) -> cs (get #sqlErrorMsg (get #sqlError enhancedSqlError)) + Nothing -> cs (displayException e) + Log.error (tshow e) + sendJSON DataSyncError { requestId, errorMessage } + Right result -> pure () + + modifyIORef' ?state (\state -> state |> modify #asyncs (handlerProcess:)) + pure () Left errorMessage -> sendJSON FailedToDecodeMessageError { errorMessage = cs errorMessage } onClose = cleanupAllSubscriptions @@ -327,16 +346,7 @@ cleanupAllSubscriptions = do let pgListener = ?applicationContext |> get #pgListener case state of - DataSyncReady { subscriptions, transactions } -> do - let channelSubscriptions = subscriptions - |> HashMap.elems - |> map (get #channelSubscription) - forEach channelSubscriptions \channelSubscription -> do - pgListener |> PGListener.unsubscribe channelSubscription - - forEach (HashMap.elems transactions) (get #releaseConnection) - - pure () + DataSyncReady { asyncs } -> forEach asyncs uninterruptibleCancel _ -> pure () changesToValue :: [ChangeNotifications.Change] -> Value @@ -369,11 +379,6 @@ findTransactionById transactionId = do Just transaction -> pure transaction Nothing -> error "No transaction with that id" -closeTransaction transactionId = do - DataSyncTransaction { releaseConnection } <- findTransactionById transactionId - modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId)) - releaseConnection - -- | Allow max 10 concurrent transactions per connection to avoid running out of database connections -- -- Each transaction removes a database connection from the connection pool. If we don't limit the transactions, @@ -389,6 +394,14 @@ ensureBelowTransactionLimit = do when (transactionCount >= maxTransactionsPerConnection) do error ("You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions") +ensureBelowSubscriptionsLimit :: (?state :: IORef DataSyncController) => IO () +ensureBelowSubscriptionsLimit = do + subscriptions <- get #subscriptions <$> readIORef ?state + let subscriptionsCount = HashMap.size subscriptions + let maxSubscriptionsPerConnection = 128 + when (subscriptionsCount >= maxSubscriptionsPerConnection) do + error ("You've reached the subscriptions limit of " <> tshow maxSubscriptionsPerConnection <> " subscriptions") + sqlQueryWithRLSAndTransactionId :: ( ?modelContext :: ModelContext , PG.ToRow parameters @@ -423,8 +436,11 @@ sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelCont $(deriveFromJSON defaultOptions 'DataSyncQuery) $(deriveToJSON defaultOptions 'DataSyncResult) -instance SetField "subscriptions" DataSyncController (HashMap UUID Subscription) where +instance SetField "subscriptions" DataSyncController (HashMap UUID (MVar.MVar ())) where setField subscriptions record = record { subscriptions } instance SetField "transactions" DataSyncController (HashMap UUID DataSyncTransaction) where - setField transactions record = record { transactions } \ No newline at end of file + setField transactions record = record { transactions } + +instance SetField "asyncs" DataSyncController [Async ()] where + setField asyncs record = record { asyncs } \ No newline at end of file diff --git a/IHP/DataSync/Types.hs b/IHP/DataSync/Types.hs index d2b193e8a..c3a4fd9df 100644 --- a/IHP/DataSync/Types.hs +++ b/IHP/DataSync/Types.hs @@ -7,6 +7,7 @@ import IHP.DataSync.DynamicQuery import Data.HashMap.Strict (HashMap) import qualified IHP.PGListener as PGListener import qualified Database.PostgreSQL.Simple as PG +import Control.Concurrent.MVar as MVar data DataSyncMessage = DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int, transactionId :: !(Maybe UUID) } @@ -42,17 +43,17 @@ data DataSyncResponse | DidRollbackTransaction { requestId :: !Int, transactionId :: !UUID } | DidCommitTransaction { requestId :: !Int, transactionId :: !UUID } -data Subscription = Subscription { id :: !UUID, channelSubscription :: !PGListener.Subscription } data DataSyncTransaction = DataSyncTransaction { id :: !UUID , connection :: !PG.Connection - , releaseConnection :: IO () + , close :: MVar () } data DataSyncController = DataSyncController | DataSyncReady - { subscriptions :: !(HashMap UUID Subscription) + { subscriptions :: !(HashMap UUID (MVar.MVar ())) , transactions :: !(HashMap UUID DataSyncTransaction) + , asyncs :: ![Async ()] }