Skip to content

Commit

Permalink
Support React.memo in ReactShallowRenderer (#14816)
Browse files Browse the repository at this point in the history
* Support React.memo in ReactShallowRenderer

ReactShallowRenderer uses element.type frequently, but with React.memo
elements the actual type is element.type.type. This updates
ReactShallowRenderer so it uses the correct element type for Memo
components and also validates the inner props for the wrapped
components.

* Allow Rect.memo to prevent re-renders

* Support memo(forwardRef())

* Dont call memo comparison function on initial render

* Fix test

* Small tweaks
  • Loading branch information
aweary authored and gaearon committed Mar 15, 2019
1 parent ad1cd7e commit 88853f1
Show file tree
Hide file tree
Showing 3 changed files with 1,717 additions and 39 deletions.
125 changes: 86 additions & 39 deletions src/ReactShallowRenderer.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/

import React from 'react';
import {isForwardRef} from 'react-is';
import {isForwardRef, isMemo, ForwardRef} from 'react-is';
import describeComponentFrame from 'shared/describeComponentFrame';
import getComponentName from 'shared/getComponentName';
import shallowEqual from 'shared/shallowEqual';
Expand Down Expand Up @@ -500,7 +500,8 @@ class ReactShallowRenderer {
element.type,
);
invariant(
isForwardRef(element) || typeof element.type === 'function',
isForwardRef(element) ||
(typeof element.type === 'function' || isMemo(element.type)),
'ReactShallowRenderer render(): Shallow rendering works only with custom ' +
'components, but the provided element type was `%s`.',
Array.isArray(element.type)
Expand All @@ -514,22 +515,36 @@ class ReactShallowRenderer {
return;
}

const elementType = isMemo(element.type) ? element.type.type : element.type;
const previousElement = this._element;

this._rendering = true;
this._element = element;
this._context = getMaskedContext(element.type.contextTypes, context);
this._context = getMaskedContext(elementType.contextTypes, context);

// Inner memo component props aren't currently validated in createElement.
if (isMemo(element.type) && elementType.propTypes) {
currentlyValidatingElement = element;
checkPropTypes(
elementType.propTypes,
element.props,
'prop',
getComponentName(elementType),
getStackAddendum,
);
}

if (this._instance) {
this._updateClassComponent(element, this._context);
this._updateClassComponent(elementType, element, this._context);
} else {
if (shouldConstruct(element.type)) {
this._instance = new element.type(
if (shouldConstruct(elementType)) {
this._instance = new elementType(
element.props,
this._context,
this._updater,
);

if (typeof element.type.getDerivedStateFromProps === 'function') {
const partialState = element.type.getDerivedStateFromProps.call(
if (typeof elementType.getDerivedStateFromProps === 'function') {
const partialState = elementType.getDerivedStateFromProps.call(
null,
element.props,
this._instance.state,
Expand All @@ -543,39 +558,59 @@ class ReactShallowRenderer {
}
}

if (element.type.hasOwnProperty('contextTypes')) {
if (elementType.contextTypes) {
currentlyValidatingElement = element;

checkPropTypes(
element.type.contextTypes,
elementType.contextTypes,
this._context,
'context',
getName(element.type, this._instance),
getName(elementType, this._instance),
getStackAddendum,
);

currentlyValidatingElement = null;
}

this._mountClassComponent(element, this._context);
this._mountClassComponent(elementType, element, this._context);
} else {
const prevDispatcher = ReactCurrentDispatcher.current;
ReactCurrentDispatcher.current = this._dispatcher;
this._prepareToUseHooks(element.type);
try {
if (isForwardRef(element)) {
this._rendered = element.type.render(element.props, element.ref);
} else {
this._rendered = element.type.call(
undefined,
element.props,
this._context,
);
let shouldRender = true;
if (
isMemo(element.type) &&
elementType === this._previousComponentIdentity &&
previousElement !== null
) {
// This is a Memo component that is being re-rendered.
const compare = element.type.compare || shallowEqual;
if (compare(previousElement.props, element.props)) {
shouldRender = false;
}
}
if (shouldRender) {
const prevDispatcher = ReactCurrentDispatcher.current;
ReactCurrentDispatcher.current = this._dispatcher;
this._prepareToUseHooks(elementType);
try {
// elementType could still be a ForwardRef if it was
// nested inside Memo.
if (elementType.$$typeof === ForwardRef) {
invariant(
typeof elementType.render === 'function',
'forwardRef requires a render function but was given %s.',
typeof elementType.render,
);
this._rendered = elementType.render.call(
undefined,
element.props,
element.ref,
);
} else {
this._rendered = elementType(element.props, this._context);
}
} finally {
ReactCurrentDispatcher.current = prevDispatcher;
}
} finally {
ReactCurrentDispatcher.current = prevDispatcher;
this._finishHooks(element, context);
}
this._finishHooks(element, context);
}
}

Expand All @@ -601,7 +636,11 @@ class ReactShallowRenderer {
this._instance = null;
}

_mountClassComponent(element: ReactElement, context: null | Object) {
_mountClassComponent(
elementType: Function,
element: ReactElement,
context: null | Object,
) {
this._instance.context = context;
this._instance.props = element.props;
this._instance.state = this._instance.state || null;
Expand All @@ -616,7 +655,7 @@ class ReactShallowRenderer {
// In order to support react-lifecycles-compat polyfilled components,
// Unsafe lifecycles should not be invoked for components using the new APIs.
if (
typeof element.type.getDerivedStateFromProps !== 'function' &&
typeof elementType.getDerivedStateFromProps !== 'function' &&
typeof this._instance.getSnapshotBeforeUpdate !== 'function'
) {
if (typeof this._instance.componentWillMount === 'function') {
Expand All @@ -638,8 +677,12 @@ class ReactShallowRenderer {
// because DOM refs are not available.
}

_updateClassComponent(element: ReactElement, context: null | Object) {
const {props, type} = element;
_updateClassComponent(
elementType: Function,
element: ReactElement,
context: null | Object,
) {
const {props} = element;

const oldState = this._instance.state || emptyObject;
const oldProps = this._instance.props;
Expand All @@ -648,7 +691,7 @@ class ReactShallowRenderer {
// In order to support react-lifecycles-compat polyfilled components,
// Unsafe lifecycles should not be invoked for components using the new APIs.
if (
typeof element.type.getDerivedStateFromProps !== 'function' &&
typeof elementType.getDerivedStateFromProps !== 'function' &&
typeof this._instance.getSnapshotBeforeUpdate !== 'function'
) {
if (typeof this._instance.componentWillReceiveProps === 'function') {
Expand All @@ -664,8 +707,8 @@ class ReactShallowRenderer {

// Read state after cWRP in case it calls setState
let state = this._newState || oldState;
if (typeof type.getDerivedStateFromProps === 'function') {
const partialState = type.getDerivedStateFromProps.call(
if (typeof elementType.getDerivedStateFromProps === 'function') {
const partialState = elementType.getDerivedStateFromProps.call(
null,
props,
state,
Expand All @@ -685,7 +728,10 @@ class ReactShallowRenderer {
state,
context,
);
} else if (type.prototype && type.prototype.isPureReactComponent) {
} else if (
elementType.prototype &&
elementType.prototype.isPureReactComponent
) {
shouldUpdate =
!shallowEqual(oldProps, props) || !shallowEqual(oldState, state);
}
Expand All @@ -694,7 +740,7 @@ class ReactShallowRenderer {
// In order to support react-lifecycles-compat polyfilled components,
// Unsafe lifecycles should not be invoked for components using the new APIs.
if (
typeof element.type.getDerivedStateFromProps !== 'function' &&
typeof elementType.getDerivedStateFromProps !== 'function' &&
typeof this._instance.getSnapshotBeforeUpdate !== 'function'
) {
if (typeof this._instance.componentWillUpdate === 'function') {
Expand Down Expand Up @@ -729,7 +775,8 @@ function getDisplayName(element) {
} else if (typeof element.type === 'string') {
return element.type;
} else {
return element.type.displayName || element.type.name || 'Unknown';
const elementType = isMemo(element.type) ? element.type.type : element.type;
return elementType.displayName || elementType.name || 'Unknown';
}
}

Expand Down
111 changes: 111 additions & 0 deletions src/__tests__/ReactShallowRenderer-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -1454,4 +1454,115 @@ describe('ReactShallowRenderer', () => {
shallowRenderer.render(<Foo foo="bar" />);
expect(logs).toEqual([undefined]);
});

it('should handle memo', () => {
function Foo() {
return <div>foo</div>;
}
const MemoFoo = React.memo(Foo);
const shallowRenderer = createRenderer();
shallowRenderer.render(<MemoFoo />);
});

it('should enable React.memo to prevent a re-render', () => {
const logs = [];
const Foo = React.memo(({count}) => {
logs.push(`Foo: ${count}`);
return <div>{count}</div>;
});
const Bar = React.memo(({count}) => {
logs.push(`Bar: ${count}`);
return <div>{count}</div>;
});
const shallowRenderer = createRenderer();
shallowRenderer.render(<Foo count={1} />);
expect(logs).toEqual(['Foo: 1']);
logs.length = 0;
// Rendering the same element with the same props should be prevented
shallowRenderer.render(<Foo count={1} />);
expect(logs).toEqual([]);
// A different element with the same props should cause a re-render
shallowRenderer.render(<Bar count={1} />);
expect(logs).toEqual(['Bar: 1']);
});

it('should respect a custom comparison function with React.memo', () => {
let renderCount = 0;
function areEqual(props, nextProps) {
return props.foo === nextProps.foo;
}
const Foo = React.memo(({foo, bar}) => {
renderCount++;
return (
<div>
{foo} {bar}
</div>
);
}, areEqual);

const shallowRenderer = createRenderer();
shallowRenderer.render(<Foo foo={1} bar={1} />);
expect(renderCount).toBe(1);
// Change a prop that the comparison funciton ignores
shallowRenderer.render(<Foo foo={1} bar={2} />);
expect(renderCount).toBe(1);
shallowRenderer.render(<Foo foo={2} bar={2} />);
expect(renderCount).toBe(2);
});

it('should not call the comparison function with React.memo on the initial render', () => {
const areEqual = jest.fn(() => false);
const SomeComponent = React.memo(({foo}) => {
return <div>{foo}</div>;
}, areEqual);
const shallowRenderer = createRenderer();
shallowRenderer.render(<SomeComponent foo={1} />);
expect(areEqual).not.toHaveBeenCalled();
expect(shallowRenderer.getRenderOutput()).toEqual(<div>{1}</div>);
});

it('should handle memo(forwardRef())', () => {
const testRef = React.createRef();
const SomeComponent = React.forwardRef((props, ref) => {
expect(ref).toEqual(testRef);
return (
<div>
<span className="child1" />
<span className="child2" />
</div>
);
});

const SomeMemoComponent = React.memo(SomeComponent);

const shallowRenderer = createRenderer();
const result = shallowRenderer.render(<SomeMemoComponent ref={testRef} />);

expect(result.type).toBe('div');
expect(result.props.children).toEqual([
<span className="child1" />,
<span className="child2" />,
]);
});

it('should warn for forwardRef(memo())', () => {
const testRef = React.createRef();
const SomeMemoComponent = React.memo(({foo}) => {
return <div>{foo}</div>;
});
const shallowRenderer = createRenderer();
expect(() => {
expect(() => {
const SomeComponent = React.forwardRef(SomeMemoComponent);
shallowRenderer.render(<SomeComponent ref={testRef} />);
}).toWarnDev(
'Warning: forwardRef requires a render function but received ' +
'a `memo` component. Instead of forwardRef(memo(...)), use ' +
'memo(forwardRef(...))',
{withoutStack: true},
);
}).toThrowError(
'forwardRef requires a render function but was given object.',
);
});
});
Loading

0 comments on commit 88853f1

Please sign in to comment.