From a6186e50488bd3e14b676bf4aedd0953be3f0574 Mon Sep 17 00:00:00 2001 From: Dominik Dorfmeister Date: Fri, 26 May 2023 15:27:38 +0200 Subject: [PATCH] feat(react-query-persist-client): await onuccess (#5473) * fix: have `onSuccess` return a promise too so that we can await it that way, invalidation or mutations that run in onSuccess will still keep isRestoring to be false, which avoids race conditions * test: test for awaiting onSuccess * docs: onSuccess --- docs/react/plugins/persistQueryClient.md | 3 +- .../src/PersistQueryClientProvider.tsx | 11 ++-- .../PersistQueryClientProvider.test.tsx | 60 +++++++++++++++++++ 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/docs/react/plugins/persistQueryClient.md b/docs/react/plugins/persistQueryClient.md index e673c33813..2f116599fc 100644 --- a/docs/react/plugins/persistQueryClient.md +++ b/docs/react/plugins/persistQueryClient.md @@ -212,10 +212,11 @@ ReactDOM.createRoot(rootElement).render( - `persistOptions: PersistQueryClientOptions` - all [options](#options) you can pass to [persistQueryClient](#persistqueryclient) minus the QueryClient itself -- `onSuccess?: () => void` +- `onSuccess?: () => Promise | unknown` - optional - will be called when the initial restore is finished - can be used to [resumePausedMutations](../reference/QueryClient#queryclientresumepausedmutations) + - if a Promise is returned, it will be awaited; restoring is seen as ongoing until then ### useIsRestoring diff --git a/packages/react-query-persist-client/src/PersistQueryClientProvider.tsx b/packages/react-query-persist-client/src/PersistQueryClientProvider.tsx index ecfc916a2b..fa4f8ec5b9 100644 --- a/packages/react-query-persist-client/src/PersistQueryClientProvider.tsx +++ b/packages/react-query-persist-client/src/PersistQueryClientProvider.tsx @@ -8,7 +8,7 @@ import { QueryClientProvider, IsRestoringProvider } from '@tanstack/react-query' export type PersistQueryClientProviderProps = QueryClientProviderProps & { persistOptions: Omit - onSuccess?: () => void + onSuccess?: () => Promise | unknown } export const PersistQueryClientProvider = ({ @@ -33,10 +33,13 @@ export const PersistQueryClientProvider = ({ queryClient: client, }) - promise.then(() => { + promise.then(async () => { if (!isStale) { - refs.current.onSuccess?.() - setIsRestoring(false) + try { + await refs.current.onSuccess?.() + } finally { + setIsRestoring(false) + } } }) diff --git a/packages/react-query-persist-client/src/__tests__/PersistQueryClientProvider.test.tsx b/packages/react-query-persist-client/src/__tests__/PersistQueryClientProvider.test.tsx index 8c3a0c59c4..a9013eefab 100644 --- a/packages/react-query-persist-client/src/__tests__/PersistQueryClientProvider.test.tsx +++ b/packages/react-query-persist-client/src/__tests__/PersistQueryClientProvider.test.tsx @@ -401,6 +401,66 @@ describe('PersistQueryClientProvider', () => { await waitFor(() => rendered.getByText('fetched')) }) + test('should await onSuccess after successful restoring', async () => { + const key = queryKey() + + const queryClient = createQueryClient() + await queryClient.prefetchQuery({ + queryKey: key, + queryFn: () => Promise.resolve('hydrated'), + }) + + const persister = createMockPersister() + + await persistQueryClientSave({ queryClient, persister }) + + queryClient.clear() + + const states: Array = [] + + function Page() { + const { data, fetchStatus } = useQuery({ + queryKey: key, + queryFn: async () => { + states.push('fetching') + await sleep(10) + states.push('fetched') + return 'fetched' + }, + }) + + return ( +
+

{data}

+

fetchStatus: {fetchStatus}

+
+ ) + } + + const rendered = render( + { + states.push('onSuccess') + await sleep(20) + states.push('onSuccess done') + }} + > + + , + ) + + await waitFor(() => rendered.getByText('hydrated')) + await waitFor(() => rendered.getByText('fetched')) + expect(states).toEqual([ + 'onSuccess', + 'onSuccess done', + 'fetching', + 'fetched', + ]) + }) + test('should remove cache after non-successful restoring', async () => { const key = queryKey() const consoleMock = vi.spyOn(console, 'error')