diff --git a/src/Query.tsx b/src/Query.tsx index 1bc105fe22..d65d2fa2f4 100644 --- a/src/Query.tsx +++ b/src/Query.tsx @@ -16,7 +16,7 @@ import { getClient } from './component-utils'; import { RenderPromises } from './getDataFromTree'; import isEqual from 'lodash.isequal'; -import shallowEqual from './utils/shallowEqual'; +import shallowEqual, { shallowEqualSansFirst } from './utils/shallowEqual'; import { invariant } from 'ts-invariant'; export type ObservableQueryFields = Pick< @@ -74,6 +74,10 @@ export interface QueryProps extend skip?: boolean; onCompleted?: (data: TData) => void; onError?: (error: ApolloError) => void; + shouldInvalidatePreviousData: ( + nextVariables: TVariables | undefined, + lastVariables: TVariables | undefined, + ) => boolean; } export interface QueryContext { @@ -94,6 +98,7 @@ export default class Query extends static propTypes = { client: PropTypes.object, children: PropTypes.func.isRequired, + shouldInvalidatePreviousData: PropTypes.func, fetchPolicy: PropTypes.string, notifyOnNetworkStatusChange: PropTypes.bool, onCompleted: PropTypes.func, @@ -105,6 +110,13 @@ export default class Query extends partialRefetch: PropTypes.bool, }; + static defaultProps = { + shouldInvalidatePreviousData: ( + nextVariables: OperationVariables | undefined, + lastVariables: OperationVariables | undefined, + ) => !shallowEqualSansFirst(nextVariables, lastVariables), + }; + context: QueryContext | undefined; private client: ApolloClient; @@ -202,6 +214,8 @@ export default class Query extends this.queryObservable = null; this.previousData = {}; this.updateQuery(nextProps); + } else if (nextProps.shouldInvalidatePreviousData(nextProps.variables, this.props.variables)) { + this.previousData = {}; } if (this.props.query !== nextProps.query) { @@ -254,7 +268,7 @@ export default class Query extends ...props, displayName, context: props.context || {}, - metadata: { reactComponent: { displayName }}, + metadata: { reactComponent: { displayName } }, }; } diff --git a/src/utils/shallowEqual.ts b/src/utils/shallowEqual.ts index b35dd0fb5a..5ebc177c05 100644 --- a/src/utils/shallowEqual.ts +++ b/src/utils/shallowEqual.ts @@ -1,3 +1,4 @@ +type AnyObject = { [key: string]: any }; const { hasOwnProperty } = Object.prototype; function is(x: any, y: any) { @@ -7,11 +8,21 @@ function is(x: any, y: any) { return x !== x && y !== y; } -function isObject(obj: any): obj is { [key: string]: any } { +function isObject(obj: any): obj is AnyObject { return obj !== null && typeof obj === "object"; } -export default function shallowEqual(objA: any, objB: any) { +function shallowObjectEqual(objA: AnyObject, objB: AnyObject) { + const keys = Object.keys(objA); + + if (keys.length !== Object.keys(objB).length) { + return false; + } + + return keys.every(key => hasOwnProperty.call(objB, key) && is(objA[key], objB[key])); +} + +export function shallowEqualSansFirst(objA: any, objB: any) { if (is(objA, objB)) { return true; } @@ -20,13 +31,21 @@ export default function shallowEqual(objA: any, objB: any) { return false; } - const keys = Object.keys(objA); + // Ignore `first` key + const { first: _, ...objectA } = objA; + const { first: __, ...objectB } = objB; - if (keys.length !== Object.keys(objB).length) { + return shallowObjectEqual(objectA, objectB); +} + +export default function shallowEqual(objA: any, objB: any) { + if (is(objA, objB)) { + return true; + } + + if (!isObject(objA) || !isObject(objB)) { return false; } - return keys.every( - key => hasOwnProperty.call(objB, key) && is(objA[key], objB[key]), - ); + return shallowObjectEqual(objA, objB); }