diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index ee5206b4b1..eecff1c41f 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -9,7 +9,7 @@ import type { QueryThunk, MutationThunk } from './buildThunks' import type { AnyAction, ThunkAction, SerializedError } from '@reduxjs/toolkit' import type { QuerySubState, SubscriptionOptions, RootState } from './apiState' import type { InternalSerializeQueryArgs } from '../defaultSerializeQueryArgs' -import type { Api } from '../apiTypes' +import type { Api, ApiContext } from '../apiTypes' import type { ApiEndpointQuery } from './module' import type { BaseQueryError } from '../baseQueryTypes' @@ -177,14 +177,22 @@ export function buildInitiate({ queryThunk, mutationThunk, api, + context, }: { serializeQueryArgs: InternalSerializeQueryArgs queryThunk: QueryThunk mutationThunk: MutationThunk api: Api + context: ApiContext }) { - const runningQueries: Record | undefined> = {} - const runningMutations: Record | undefined> = {} + const runningQueries: Record< + string, + QueryActionCreatorResult | undefined + > = {} + const runningMutations: Record< + string, + MutationActionCreatorResult | undefined + > = {} const { unsubscribeQueryResult, @@ -195,12 +203,23 @@ export function buildInitiate({ buildInitiateQuery, buildInitiateMutation, getRunningOperationPromises, + getRunningQueryPromise, } function getRunningOperationPromises() { - return Object.values(runningQueries) - .concat(Object.values(runningMutations)) - .filter((t): t is Promise => !!t) + return [ + ...Object.values(runningQueries), + ...Object.values(runningMutations), + ].filter((t: T | undefined): t is T => !!t) + } + + function getRunningQueryPromise(endpointName: string, queryArgs: any) { + const queryCacheKey = serializeQueryArgs({ + queryArgs, + endpointDefinition: context.endpointDefinitions[endpointName], + endpointName, + }) + return runningQueries[queryCacheKey] } function middlewareWarning(getState: () => RootState<{}, string, string>) { @@ -243,7 +262,7 @@ Features like automatic cache collection, automatic refetching etc. will not be const thunkResult = dispatch(thunk) middlewareWarning(getState) const { requestId, abort } = thunkResult - const statePromise = Object.assign( + const statePromise: QueryActionCreatorResult = Object.assign( Promise.all([runningQueries[queryCacheKey], thunkResult]).then(() => (api.endpoints[endpointName] as ApiEndpointQuery).select( arg @@ -287,8 +306,8 @@ Features like automatic cache collection, automatic refetching etc. will not be }) if (!runningQueries[queryCacheKey]) { - runningQueries[queryCacheKey] = thunkResult - thunkResult.then(() => { + runningQueries[queryCacheKey] = statePromise + statePromise.then(() => { delete runningQueries[queryCacheKey] }) } @@ -325,13 +344,11 @@ Features like automatic cache collection, automatic refetching etc. will not be if (track) dispatch(unsubscribeMutationResult({ requestId })) }, }) - ret.then(() => { - ret.resolved = true - }) - runningMutations[requestId] = thunkResult - thunkResult.then(() => { + runningMutations[requestId] = ret + ret.then(() => { delete runningMutations[requestId] + ret.resolved = true }) return ret diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index f62e50712f..5683873fb4 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -26,6 +26,7 @@ import { onFocus, onFocusLost, onOnline, onOffline } from './setupListeners' import { buildSlice } from './buildSlice' import { buildMiddleware } from './buildMiddleware' import { buildSelectors } from './buildSelectors' +import type { QueryActionCreatorResult } from './buildInitiate' import { buildInitiate } from './buildInitiate' import { assertCast, safeAssign } from '../tsHelpers' import type { InternalSerializeQueryArgs } from '../defaultSerializeQueryArgs' @@ -264,6 +265,10 @@ declare module '../apiTypes' { : never } getRunningOperationPromises: () => Array> + getRunningQueryPromise: >( + endpointName: EndpointName, + args: QueryArgFrom + ) => QueryActionCreatorResult | undefined } } } @@ -443,14 +448,16 @@ export const coreModule = (): Module => ({ buildInitiateQuery, buildInitiateMutation, getRunningOperationPromises, + getRunningQueryPromise, } = buildInitiate({ queryThunk, mutationThunk, api, serializeQueryArgs: serializeQueryArgs as any, + context, }) - api.getRunningOperationPromises = getRunningOperationPromises + safeAssign(api, { getRunningOperationPromises, getRunningQueryPromise }) return { name: coreModuleName,