Skip to content

Commit

Permalink
add api.util.selectInvalidatedBy (#1665)
Browse files Browse the repository at this point in the history
  • Loading branch information
phryneas authored Oct 29, 2021
1 parent 0a16fb5 commit 5e36d3a
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { isAnyOf, isFulfilled, isRejectedWithValue } from '@reduxjs/toolkit'

import type { FullTagDescription } from '../../endpointDefinitions'
import { calculateProvidedBy } from '../../endpointDefinitions'
import { flatten } from '../../utils'
import type { QueryCacheKey } from '../apiState'
import { QueryStatus } from '../apiState'
import { calculateProvidedByThunk } from '../buildThunks'
Expand Down Expand Up @@ -59,39 +58,27 @@ export const build: SubMiddlewareBuilder = ({

function invalidateTags(
tags: readonly FullTagDescription<string>[],
api: SubMiddlewareApi
mwApi: SubMiddlewareApi
) {
const state = api.getState()[reducerPath]
const rootState = mwApi.getState()
const state = rootState[reducerPath]

const toInvalidate = new Set<QueryCacheKey>()
for (const tag of tags) {
const provided = state.provided[tag.type]
if (!provided) {
continue
}

let invalidateSubscriptions =
(tag.id !== undefined
? // id given: invalidate all queries that provide this type & id
provided[tag.id]
: // no id: invalidate all queries that provide this type
flatten(Object.values(provided))) ?? []

for (const invalidate of invalidateSubscriptions) {
toInvalidate.add(invalidate)
}
}
const toInvalidate = api.util.selectInvalidatedBy(rootState, tags)

context.batch(() => {
const valuesArray = Array.from(toInvalidate.values())
for (const queryCacheKey of valuesArray) {
for (const { queryCacheKey } of valuesArray) {
const querySubState = state.queries[queryCacheKey]
const subscriptionSubState = state.subscriptions[queryCacheKey]
if (querySubState && subscriptionSubState) {
if (Object.keys(subscriptionSubState).length === 0) {
api.dispatch(removeQueryResult({ queryCacheKey }))
mwApi.dispatch(
removeQueryResult({
queryCacheKey: queryCacheKey as QueryCacheKey,
})
)
} else if (querySubState.status !== QueryStatus.uninitialized) {
api.dispatch(refetchQuery(querySubState, queryCacheKey))
mwApi.dispatch(refetchQuery(querySubState, queryCacheKey))
} else {
}
}
Expand Down
50 changes: 49 additions & 1 deletion packages/toolkit/src/query/core/buildSelectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type {
QuerySubState,
RootState as _RootState,
RequestStatusFlags,
QueryCacheKey,
} from './apiState'
import { QueryStatus, getRequestStatusFlags } from './apiState'
import type {
Expand All @@ -13,9 +14,12 @@ import type {
QueryArgFrom,
TagTypesFrom,
ReducerPathFrom,
TagDescription,
} from '../endpointDefinitions'
import { expandTagDescription } from '../endpointDefinitions'
import type { InternalSerializeQueryArgs } from '../defaultSerializeQueryArgs'
import { getMutationCacheKey } from './buildSlice'
import { flatten } from '../utils'

export type SkipToken = typeof skipToken
/**
Expand Down Expand Up @@ -125,7 +129,7 @@ export function buildSelectors<
}) {
type RootState = _RootState<Definitions, string, string>

return { buildQuerySelector, buildMutationSelector }
return { buildQuerySelector, buildMutationSelector, selectInvalidatedBy }

function withRequestFlags<T extends { status: QueryStatus }>(
substate: T
Expand Down Expand Up @@ -193,4 +197,48 @@ export function buildSelectors<
return createSelector(selectMutationSubstate, withRequestFlags)
}
}

function selectInvalidatedBy(
state: RootState,
tags: ReadonlyArray<TagDescription<string>>
): Array<{
endpointName: string
originalArgs: any
queryCacheKey: QueryCacheKey
}> {
const apiState = state[reducerPath]
const toInvalidate = new Set<QueryCacheKey>()
for (const tag of tags.map(expandTagDescription)) {
const provided = apiState.provided[tag.type]
if (!provided) {
continue
}

let invalidateSubscriptions =
(tag.id !== undefined
? // id given: invalidate all queries that provide this type & id
provided[tag.id]
: // no id: invalidate all queries that provide this type
flatten(Object.values(provided))) ?? []

for (const invalidate of invalidateSubscriptions) {
toInvalidate.add(invalidate)
}
}

return flatten(
Array.from(toInvalidate.values()).map((queryCacheKey) => {
const querySubState = apiState.queries[queryCacheKey]
return querySubState
? [
{
queryCacheKey,
endpointName: querySubState.endpointName!,
originalArgs: querySubState.originalArgs,
},
]
: []
})
)
}
}
24 changes: 18 additions & 6 deletions packages/toolkit/src/query/core/module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import type {
QueryDefinition,
MutationDefinition,
AssertTagTypes,
FullTagDescription,
TagDescription,
} from '../endpointDefinitions'
import { isQueryDefinition, isMutationDefinition } from '../endpointDefinitions'
import type {
Expand Down Expand Up @@ -274,9 +274,18 @@ declare module '../apiTypes' {
* ```
*/
invalidateTags: ActionCreatorWithPayload<
Array<TagTypes | FullTagDescription<TagTypes>>,
Array<TagDescription<TagTypes>>,
string
>

selectInvalidatedBy: (
state: RootState<Definitions, string, ReducerPath>,
tags: ReadonlyArray<TagDescription<TagTypes>>
) => Array<{
endpointName: string
originalArgs: any
queryCacheKey: string
}>
}
/**
* Endpoints based on the input endpoints provided to `createApi`, containing `select` and `action matchers`.
Expand Down Expand Up @@ -463,10 +472,13 @@ export const coreModule = (): Module<CoreModule> => ({

safeAssign(api, { reducer: reducer as any, middleware })

const { buildQuerySelector, buildMutationSelector } = buildSelectors({
serializeQueryArgs: serializeQueryArgs as any,
reducerPath,
})
const { buildQuerySelector, buildMutationSelector, selectInvalidatedBy } =
buildSelectors({
serializeQueryArgs: serializeQueryArgs as any,
reducerPath,
})

safeAssign(api.util, { selectInvalidatedBy })

const {
buildInitiateQuery,
Expand Down
2 changes: 1 addition & 1 deletion packages/toolkit/src/query/endpointDefinitions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ function isFunction<T>(t: T): t is Extract<T, Function> {
return typeof t === 'function'
}

function expandTagDescription(
export function expandTagDescription(
description: TagDescription<string>
): FullTagDescription<string> {
return typeof description === 'string' ? { type: description } : description
Expand Down
50 changes: 46 additions & 4 deletions packages/toolkit/src/query/tests/invalidation.test.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import { createApi, fakeBaseQuery } from '@reduxjs/toolkit/query'
import { setupApiStore, waitMs } from './helpers'
import type { ResultDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions'
import type { TagDescription } from '@reduxjs/toolkit/dist/query/endpointDefinitions'
import { waitFor } from '@testing-library/react'

const tagTypes = ['apple', 'pear', 'banana', 'tomato'] as const
const tagTypes = [
'apple',
'pear',
'banana',
'tomato',
'cat',
'dog',
'giraffe',
] as const
type TagTypes = typeof tagTypes[number]
type Tags = ResultDescription<TagTypes, any, any, any>
type Tags = TagDescription<TagTypes>[]

/** providesTags, invalidatesTags, shouldInvalidate */
const caseMatrix: [Tags, Tags, boolean][] = [
Expand Down Expand Up @@ -62,8 +71,9 @@ test.each(caseMatrix)(
let queryCount = 0
const {
store,
api,
api: {
endpoints: { invalidating, providing },
endpoints: { invalidating, providing, unrelated },
},
} = setupApiStore(
createApi({
Expand All @@ -77,6 +87,12 @@ test.each(caseMatrix)(
},
providesTags,
}),
unrelated: build.query<unknown, void>({
queryFn() {
return { data: {} }
},
providesTags: ['cat', 'dog', { type: 'giraffe', id: 8 }],
}),
invalidating: build.mutation<unknown, void>({
queryFn() {
return { data: {} }
Expand All @@ -88,7 +104,33 @@ test.each(caseMatrix)(
)

store.dispatch(providing.initiate())
store.dispatch(unrelated.initiate())
expect(queryCount).toBe(1)
await waitFor(() => {
expect(api.endpoints.providing.select()(store.getState()).status).toBe(
'fulfilled'
)
expect(api.endpoints.unrelated.select()(store.getState()).status).toBe(
'fulfilled'
)
})
const toInvalidate = api.util.selectInvalidatedBy(
store.getState(),
invalidatesTags
)

if (shouldInvalidate) {
expect(toInvalidate).toEqual([
{
queryCacheKey: 'providing(undefined)',
endpointName: 'providing',
originalArgs: undefined,
},
])
} else {
expect(toInvalidate).toEqual([])
}

store.dispatch(invalidating.initiate())
expect(queryCount).toBe(1)
await waitMs(2)
Expand Down

0 comments on commit 5e36d3a

Please sign in to comment.