From 5e36d3a9cd78010f9028922a27562b571105a127 Mon Sep 17 00:00:00 2001 From: Lenz Weber Date: Fri, 29 Oct 2021 19:05:21 +0200 Subject: [PATCH] add `api.util.selectInvalidatedBy` (#1665) --- .../buildMiddleware/invalidationByTags.ts | 35 ++++--------- .../toolkit/src/query/core/buildSelectors.ts | 50 ++++++++++++++++++- packages/toolkit/src/query/core/module.ts | 24 ++++++--- .../toolkit/src/query/endpointDefinitions.ts | 2 +- .../src/query/tests/invalidation.test.tsx | 50 +++++++++++++++++-- 5 files changed, 125 insertions(+), 36 deletions(-) diff --git a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts index 734515648a..b84bfb5624 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts @@ -2,7 +2,6 @@ import { isAnyOf, isFulfilled, isRejectedWithValue } from '@reduxjs/toolkit' import type { FullTagDescription } from '../../endpointDefinitions' import { calculateProvidedBy } from '../../endpointDefinitions' -import { flatten } from '../../utils' import type { QueryCacheKey } from '../apiState' import { QueryStatus } from '../apiState' import { calculateProvidedByThunk } from '../buildThunks' @@ -59,39 +58,27 @@ export const build: SubMiddlewareBuilder = ({ function invalidateTags( tags: readonly FullTagDescription[], - api: SubMiddlewareApi + mwApi: SubMiddlewareApi ) { - const state = api.getState()[reducerPath] + const rootState = mwApi.getState() + const state = rootState[reducerPath] - const toInvalidate = new Set() - for (const tag of tags) { - const provided = state.provided[tag.type] - if (!provided) { - continue - } - - let invalidateSubscriptions = - (tag.id !== undefined - ? // id given: invalidate all queries that provide this type & id - provided[tag.id] - : // no id: invalidate all queries that provide this type - flatten(Object.values(provided))) ?? [] - - for (const invalidate of invalidateSubscriptions) { - toInvalidate.add(invalidate) - } - } + const toInvalidate = api.util.selectInvalidatedBy(rootState, tags) context.batch(() => { const valuesArray = Array.from(toInvalidate.values()) - for (const queryCacheKey of valuesArray) { + for (const { queryCacheKey } of valuesArray) { const querySubState = state.queries[queryCacheKey] const subscriptionSubState = state.subscriptions[queryCacheKey] if (querySubState && subscriptionSubState) { if (Object.keys(subscriptionSubState).length === 0) { - api.dispatch(removeQueryResult({ queryCacheKey })) + mwApi.dispatch( + removeQueryResult({ + queryCacheKey: queryCacheKey as QueryCacheKey, + }) + ) } else if (querySubState.status !== QueryStatus.uninitialized) { - api.dispatch(refetchQuery(querySubState, queryCacheKey)) + mwApi.dispatch(refetchQuery(querySubState, queryCacheKey)) } else { } } diff --git a/packages/toolkit/src/query/core/buildSelectors.ts b/packages/toolkit/src/query/core/buildSelectors.ts index dc022b901a..733c752bad 100644 --- a/packages/toolkit/src/query/core/buildSelectors.ts +++ b/packages/toolkit/src/query/core/buildSelectors.ts @@ -4,6 +4,7 @@ import type { QuerySubState, RootState as _RootState, RequestStatusFlags, + QueryCacheKey, } from './apiState' import { QueryStatus, getRequestStatusFlags } from './apiState' import type { @@ -13,9 +14,12 @@ import type { QueryArgFrom, TagTypesFrom, ReducerPathFrom, + TagDescription, } from '../endpointDefinitions' +import { expandTagDescription } from '../endpointDefinitions' import type { InternalSerializeQueryArgs } from '../defaultSerializeQueryArgs' import { getMutationCacheKey } from './buildSlice' +import { flatten } from '../utils' export type SkipToken = typeof skipToken /** @@ -125,7 +129,7 @@ export function buildSelectors< }) { type RootState = _RootState - return { buildQuerySelector, buildMutationSelector } + return { buildQuerySelector, buildMutationSelector, selectInvalidatedBy } function withRequestFlags( substate: T @@ -193,4 +197,48 @@ export function buildSelectors< return createSelector(selectMutationSubstate, withRequestFlags) } } + + function selectInvalidatedBy( + state: RootState, + tags: ReadonlyArray> + ): Array<{ + endpointName: string + originalArgs: any + queryCacheKey: QueryCacheKey + }> { + const apiState = state[reducerPath] + const toInvalidate = new Set() + for (const tag of tags.map(expandTagDescription)) { + const provided = apiState.provided[tag.type] + if (!provided) { + continue + } + + let invalidateSubscriptions = + (tag.id !== undefined + ? // id given: invalidate all queries that provide this type & id + provided[tag.id] + : // no id: invalidate all queries that provide this type + flatten(Object.values(provided))) ?? [] + + for (const invalidate of invalidateSubscriptions) { + toInvalidate.add(invalidate) + } + } + + return flatten( + Array.from(toInvalidate.values()).map((queryCacheKey) => { + const querySubState = apiState.queries[queryCacheKey] + return querySubState + ? [ + { + queryCacheKey, + endpointName: querySubState.endpointName!, + originalArgs: querySubState.originalArgs, + }, + ] + : [] + }) + ) + } } diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 661d76eae1..db7abfb063 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -17,7 +17,7 @@ import type { QueryDefinition, MutationDefinition, AssertTagTypes, - FullTagDescription, + TagDescription, } from '../endpointDefinitions' import { isQueryDefinition, isMutationDefinition } from '../endpointDefinitions' import type { @@ -274,9 +274,18 @@ declare module '../apiTypes' { * ``` */ invalidateTags: ActionCreatorWithPayload< - Array>, + Array>, string > + + selectInvalidatedBy: ( + state: RootState, + tags: ReadonlyArray> + ) => Array<{ + endpointName: string + originalArgs: any + queryCacheKey: string + }> } /** * Endpoints based on the input endpoints provided to `createApi`, containing `select` and `action matchers`. @@ -463,10 +472,13 @@ export const coreModule = (): Module => ({ safeAssign(api, { reducer: reducer as any, middleware }) - const { buildQuerySelector, buildMutationSelector } = buildSelectors({ - serializeQueryArgs: serializeQueryArgs as any, - reducerPath, - }) + const { buildQuerySelector, buildMutationSelector, selectInvalidatedBy } = + buildSelectors({ + serializeQueryArgs: serializeQueryArgs as any, + reducerPath, + }) + + safeAssign(api.util, { selectInvalidatedBy }) const { buildInitiateQuery, diff --git a/packages/toolkit/src/query/endpointDefinitions.ts b/packages/toolkit/src/query/endpointDefinitions.ts index 0ba02c581f..dfd4565a8c 100644 --- a/packages/toolkit/src/query/endpointDefinitions.ts +++ b/packages/toolkit/src/query/endpointDefinitions.ts @@ -453,7 +453,7 @@ function isFunction(t: T): t is Extract { return typeof t === 'function' } -function expandTagDescription( +export function expandTagDescription( description: TagDescription ): FullTagDescription { return typeof description === 'string' ? { type: description } : description diff --git a/packages/toolkit/src/query/tests/invalidation.test.tsx b/packages/toolkit/src/query/tests/invalidation.test.tsx index afcd052e03..b5257f868f 100644 --- a/packages/toolkit/src/query/tests/invalidation.test.tsx +++ b/packages/toolkit/src/query/tests/invalidation.test.tsx @@ -1,10 +1,19 @@ import { createApi, fakeBaseQuery } from '@reduxjs/toolkit/query' import { setupApiStore, waitMs } from './helpers' -import type { ResultDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions' +import type { TagDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions' +import { waitFor } from '@testing-library/react' -const tagTypes = ['apple', 'pear', 'banana', 'tomato'] as const +const tagTypes = [ + 'apple', + 'pear', + 'banana', + 'tomato', + 'cat', + 'dog', + 'giraffe', +] as const type TagTypes = typeof tagTypes[number] -type Tags = ResultDescription +type Tags = TagDescription[] /** providesTags, invalidatesTags, shouldInvalidate */ const caseMatrix: [Tags, Tags, boolean][] = [ @@ -62,8 +71,9 @@ test.each(caseMatrix)( let queryCount = 0 const { store, + api, api: { - endpoints: { invalidating, providing }, + endpoints: { invalidating, providing, unrelated }, }, } = setupApiStore( createApi({ @@ -77,6 +87,12 @@ test.each(caseMatrix)( }, providesTags, }), + unrelated: build.query({ + queryFn() { + return { data: {} } + }, + providesTags: ['cat', 'dog', { type: 'giraffe', id: 8 }], + }), invalidating: build.mutation({ queryFn() { return { data: {} } @@ -88,7 +104,33 @@ test.each(caseMatrix)( ) store.dispatch(providing.initiate()) + store.dispatch(unrelated.initiate()) expect(queryCount).toBe(1) + await waitFor(() => { + expect(api.endpoints.providing.select()(store.getState()).status).toBe( + 'fulfilled' + ) + expect(api.endpoints.unrelated.select()(store.getState()).status).toBe( + 'fulfilled' + ) + }) + const toInvalidate = api.util.selectInvalidatedBy( + store.getState(), + invalidatesTags + ) + + if (shouldInvalidate) { + expect(toInvalidate).toEqual([ + { + queryCacheKey: 'providing(undefined)', + endpointName: 'providing', + originalArgs: undefined, + }, + ]) + } else { + expect(toInvalidate).toEqual([]) + } + store.dispatch(invalidating.initiate()) expect(queryCount).toBe(1) await waitMs(2)