diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index d7a8731030..2d12b5cea1 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -1,14 +1,15 @@ import type { EndpointDefinitions, + QueryDefinition, MutationDefinition, QueryArgFrom, - QueryDefinition, ResultTypeFrom, } from '../endpointDefinitions' import { DefinitionType } from '../endpointDefinitions' -import type { MutationThunk, QueryThunk } from './buildThunks' -import type { AnyAction, SerializedError, ThunkAction } from '@reduxjs/toolkit' -import type { RootState, SubscriptionOptions } from './apiState' +import type { QueryThunk, MutationThunk } from './buildThunks' +import type { AnyAction, ThunkAction, SerializedError } from '@reduxjs/toolkit' +import type { SubscriptionOptions, RootState } from './apiState' +import { QueryStatus } from './apiState' import type { InternalSerializeQueryArgs } from '../defaultSerializeQueryArgs' import type { Api, ApiContext } from '../apiTypes' import type { ApiEndpointQuery } from './module' @@ -191,12 +192,10 @@ export function buildInitiate({ api: Api context: ApiContext }) { - // keep track of running queries by id const runningQueries: Record< string, - Record> + QueryActionCreatorResult | undefined > = {} - // keep track of running mutations by id const runningMutations: Record< string, MutationActionCreatorResult | undefined @@ -225,10 +224,7 @@ export function buildInitiate({ endpointDefinition, endpointName, }) - // TODO(manuel) this is not really what we want, because we don't know which of those thunks will actually resolve to the correct result - return Promise.all( - Object.values(runningQueries[queryCacheKey] || {}) - ).then((x) => x[0]) + return runningQueries[queryCacheKey] } else { return runningMutations[argOrRequestId] } @@ -236,9 +232,7 @@ export function buildInitiate({ function getRunningOperationPromises() { return [ - ...Object.values(runningQueries) - .map((x) => Object.values(x)) - .reduce((x, y) => x.concat(y)), + ...Object.values(runningQueries), ...Object.values(runningMutations), ].filter((t: T | undefined): t is T => !!t) } @@ -281,20 +275,27 @@ Features like automatic cache collection, automatic refetching etc. will not be originalArgs: arg, queryCacheKey, }) + const selector = ( + api.endpoints[endpointName] as ApiEndpointQuery + ).select(arg) + const thunkResult = dispatch(thunk) + const stateAfter = selector(getState()) + middlewareWarning(getState) const { requestId, abort } = thunkResult - const prevThunks = Object.values(runningQueries?.[queryCacheKey] || {}) + const skippedSynchronously = stateAfter.requestId !== requestId + + const runningQuery = runningQueries[queryCacheKey] - let promises: Promise[] = [...prevThunks, thunkResult] const statePromise: QueryActionCreatorResult = Object.assign( - Promise.all(promises).then(() => { - return ( - api.endpoints[endpointName] as ApiEndpointQuery - ).select(arg)(getState()) - }), + skippedSynchronously && !runningQuery + ? Promise.resolve(stateAfter) + : Promise.all([runningQuery, thunkResult]).then(() => + selector(getState()) + ), { arg, requestId, @@ -338,15 +339,13 @@ Features like automatic cache collection, automatic refetching etc. will not be } ) - if (!runningQueries.hasOwnProperty(queryCacheKey)) { - runningQueries[queryCacheKey] = {} + if (!runningQuery && !skippedSynchronously) { + runningQueries[queryCacheKey] = statePromise + statePromise.then(() => { + delete runningQueries[queryCacheKey] + }) } - runningQueries[queryCacheKey][requestId] = statePromise - statePromise.then(() => { - delete runningQueries?.[queryCacheKey]?.[requestId] - }) - return statePromise } return queryAction diff --git a/packages/toolkit/src/query/tests/buildInitiate.test.tsx b/packages/toolkit/src/query/tests/buildInitiate.test.tsx new file mode 100644 index 0000000000..b16182177a --- /dev/null +++ b/packages/toolkit/src/query/tests/buildInitiate.test.tsx @@ -0,0 +1,54 @@ +import { createApi } from '../core' +import { fakeBaseQuery } from '../fakeBaseQuery' +import { setupApiStore } from './helpers' + +let calls = 0 +const api = createApi({ + baseQuery: fakeBaseQuery(), + endpoints: (build) => ({ + increment: build.query({ + async queryFn() { + const data = calls++ + await Promise.resolve() + return { data } + }, + }), + }), +}) + +const storeRef = setupApiStore(api) + +test('multiple synchonrous initiate calls with pre-existing cache entry', async () => { + const { store, api } = storeRef + // seed the store + const firstValue = await store.dispatch(api.endpoints.increment.initiate()) + + expect(firstValue).toMatchObject({ data: 0, status: 'fulfilled' }) + + // dispatch another increment + const secondValuePromise = store.dispatch(api.endpoints.increment.initiate()) + // and one with a forced refresh + const thirdValuePromise = store.dispatch( + api.endpoints.increment.initiate(undefined, { forceRefetch: true }) + ) + // and another increment + const fourthValuePromise = store.dispatch(api.endpoints.increment.initiate()) + + const secondValue = await secondValuePromise + const thirdValue = await thirdValuePromise + const fourthValue = await fourthValuePromise + + expect(secondValue).toMatchObject({ + data: firstValue.data, + status: 'fulfilled', + requestId: firstValue.requestId, + }) + + expect(thirdValue).toMatchObject({ data: 1, status: 'fulfilled' }) + expect(thirdValue.requestId).not.toBe(firstValue.requestId) + expect(fourthValue).toMatchObject({ + data: thirdValue.data, + status: 'fulfilled', + requestId: thirdValue.requestId, + }) +})