diff --git a/src/disposeOnUnmount.js b/src/disposeOnUnmount.js index 16012fb4..c74fc543 100644 --- a/src/disposeOnUnmount.js +++ b/src/disposeOnUnmount.js @@ -6,7 +6,7 @@ const storeKey = newSymbol("disposeOnUnmount") function runDisposersOnWillUnmount() { if (!this[storeKey]) { // when disposeOnUnmount is only set to some instances of a component it will still patch the prototype - return; + return } this[storeKey].forEach(propKeyOrFunction => { const prop = @@ -46,7 +46,7 @@ export function disposeOnUnmount(target, propertyKeyOrFunction) { // tweak the component class componentWillUnmount if not done already if (!componentWasAlreadyModified) { - patch(target, "componentWillUnmount", runDisposersOnWillUnmount, false) + patch(target, "componentWillUnmount", runDisposersOnWillUnmount) } // return the disposer as is if invoked as a non decorator diff --git a/src/observer.js b/src/observer.js index bf97930a..3d9f79bb 100644 --- a/src/observer.js +++ b/src/observer.js @@ -91,8 +91,8 @@ export const errorsReporter = new EventEmitter() * Utilities */ -function patch(target, funcName, runMixinFirst = false) { - newPatch(target, funcName, reactiveMixin[funcName], runMixinFirst) +function patch(target, funcName) { + newPatch(target, funcName, reactiveMixin[funcName]) } function shallowEqual(objA, objB) { diff --git a/src/utils/utils.js b/src/utils/utils.js index abcbb5af..10880ce4 100644 --- a/src/utils/utils.js +++ b/src/utils/utils.js @@ -15,100 +15,94 @@ export function newSymbol(name) { } const mobxMixins = newSymbol("patchMixins") -const mobxMixin = newSymbol("patchMixin") +const mobxPatchedDefinition = newSymbol("patchedDefinition") +const mobxRealMethod = newSymbol("patchRealMethod") -function getCreateMixins(target, methodName) { +function getMixins(target, methodName) { const mixins = (target[mobxMixins] = target[mobxMixins] || {}) const methodMixins = (mixins[methodName] = mixins[methodName] || {}) - methodMixins.pre = methodMixins.pre || [] - methodMixins.post = methodMixins.post || [] + methodMixins.locks = methodMixins.locks || 0 + methodMixins.methods = methodMixins.methods || [] return methodMixins } -function getMixins(target, methodName) { - return target[mobxMixins][methodName] -} - -const cachedDefinitions = {} - -function createOrGetCachedDefinition(methodName, enumerable) { - const cacheKey = `${methodName}+${enumerable}` - const cached = cachedDefinitions[cacheKey] - if (cached) { - return cached - } - - const wrapperMethod = function wrapperMethod(...args) { - const mixins = getMixins(this, methodName) - - // avoid possible recursive calls by custom patches - if (mixins.realRunning) { - return - } - mixins.realRunning = true +function wrapper(realMethod, mixins, ...args) { + // locks are used to ensure that mixins are invoked only once per invocation, even on recursive calls + mixins.locks++ - const realMethod = mixins.real + try { let retVal + if (realMethod !== undefined && realMethod !== null) { + retVal = realMethod.apply(this, args) + } - try { - mixins.pre.forEach(pre => { - pre.apply(this, args) - }) - - if (realMethod !== undefined && realMethod !== null) { - retVal = realMethod.apply(this, args) - } - - mixins.post.forEach(post => { - post.apply(this, args) + return retVal + } finally { + mixins.locks-- + if (mixins.locks === 0) { + mixins.methods.forEach(mx => { + mx.apply(this, args) }) - - return retVal - } finally { - mixins.realRunning = false } } - wrapperMethod[mobxMixin] = true +} - const newDefinition = { - get() { - return wrapperMethod - }, - set(value) { - const mixins = getMixins(this, methodName) - mixins.real = value - }, - configurable: true, - enumerable: enumerable +function wrapFunction(mixins) { + const fn = function(...args) { + wrapper.call(this, fn[mobxRealMethod], mixins, ...args) } - - cachedDefinitions[cacheKey] = newDefinition - - return newDefinition + return fn } -export function patch(target, methodName, mixinMethod, runMixinFirst = false) { - const mixins = getCreateMixins(target, methodName) +export function patch(target, methodName, ...mixinMethods) { + const mixins = getMixins(target, methodName) - if (runMixinFirst) { - mixins.pre.unshift(mixinMethod) - } else { - mixins.post.push(mixinMethod) + for (const mixinMethod of mixinMethods) { + if (mixins.methods.indexOf(mixinMethod) < 0) { + mixins.methods.push(mixinMethod) + } } - const realMethod = target[methodName] - if (typeof realMethod === "function" && realMethod[mobxMixin]) { - // already patched, do not repatch + const oldDefinition = Object.getOwnPropertyDescriptor(target, methodName) + if (oldDefinition && oldDefinition[mobxPatchedDefinition]) { + // already patched definition, do not repatch return } - mixins.real = realMethod - - const oldDefinition = Object.getOwnPropertyDescriptor(target, methodName) - const newDefinition = createOrGetCachedDefinition( + const originalMethod = target[methodName] + const newDefinition = createDefinition( + target, methodName, - oldDefinition ? oldDefinition.enumerable : undefined + oldDefinition ? oldDefinition.enumerable : undefined, + mixins, + originalMethod ) Object.defineProperty(target, methodName, newDefinition) } + +function createDefinition(target, methodName, enumerable, mixins, originalMethod) { + const wrappedFunc = wrapFunction(mixins) + wrappedFunc[mobxRealMethod] = originalMethod + + return { + [mobxPatchedDefinition]: true, + get: function() { + return wrappedFunc + }, + set: function(value) { + if (this === target) { + wrappedFunc[mobxRealMethod] = value + } else { + // when it is an instance of the prototype/a child prototype patch that particular case again separately + // since we need to store separate values depending on wether it is the actual instance, the prototype, etc + // e.g. the method for super might not be the same as the method for the prototype which might be not the same + // as the method for the instance + const newDefinition = createDefinition(this, methodName, enumerable, mixins, value) + Object.defineProperty(this, methodName, newDefinition) + } + }, + configurable: true, + enumerable: enumerable + } +} diff --git a/test/disposeOnUnmount.test.js b/test/disposeOnUnmount.test.js index 1d1c1e1e..1aed6f24 100644 --- a/test/disposeOnUnmount.test.js +++ b/test/disposeOnUnmount.test.js @@ -17,8 +17,8 @@ async function testComponent(C, afterMount, afterUnmount) { await asyncReactDOMRender(null, testRoot) - expect(cref.methodA).toHaveBeenCalled() - expect(cref.methodB).toHaveBeenCalled() + expect(cref.methodA).toHaveBeenCalledTimes(1) + expect(cref.methodB).toHaveBeenCalledTimes(1) if (afterUnmount) { afterUnmount(cref) } @@ -134,8 +134,8 @@ describe("without observer", () => { expect(methodD).not.toHaveBeenCalled() }, () => { - expect(methodC).toHaveBeenCalled() - expect(methodD).toHaveBeenCalled() + expect(methodC).toHaveBeenCalledTimes(1) + expect(methodD).toHaveBeenCalledTimes(1) } ) }) @@ -256,8 +256,8 @@ describe("with observer", () => { expect(methodD).not.toHaveBeenCalled() }, () => { - expect(methodC).toHaveBeenCalled() - expect(methodD).toHaveBeenCalled() + expect(methodC).toHaveBeenCalledTimes(1) + expect(methodD).toHaveBeenCalledTimes(1) } ) }) @@ -347,8 +347,91 @@ test("custom patching should work", async () => { ) }) +describe("super calls should work", async () => { + async function doTest(baseObserver, cObserver) { + const events = [] + + const sharedMethod = jest.fn() + + class BaseComponent extends React.Component { + @disposeOnUnmount + method0 = sharedMethod + + @disposeOnUnmount + methodA = jest.fn() + + componentDidMount() { + events.push("baseDidMount") + } + + componentWillUnmount() { + events.push("baseWillUnmount") + } + } + + class C extends BaseComponent { + @disposeOnUnmount + method0 = sharedMethod + + @disposeOnUnmount + methodB = jest.fn() + + componentDidMount() { + super.componentDidMount() + events.push("CDidMount") + } + + componentWillUnmount() { + super.componentWillUnmount() + events.push("CWillUnmount") + } + + render() { + return null + } + } + + if (baseObserver) { + BaseComponent = observer(BaseComponent) + } + if (cObserver) { + C = observer(C) + } + + await testComponent( + C, + ref => { + expect(events).toEqual(["baseDidMount", "CDidMount"]) + expect(sharedMethod).toHaveBeenCalledTimes(0) + }, + ref => { + expect(events).toEqual([ + "baseDidMount", + "CDidMount", + "baseWillUnmount", + "CWillUnmount" + ]) + expect(sharedMethod).toHaveBeenCalledTimes(2) + } + ) + } + + it("none is observer", async () => { + await doTest(false, false) + }) + it("base is observer", async () => { + await doTest(true, false) + }) + it("C is observer", async () => { + await doTest(false, true) + }) + it("both observers", async () => { + await doTest(true, true) + }) +}) + it("componentDidMount should be different between components", async () => { - async function test(withObserver) { + async function doTest(withObserver) { const events = [] class A extends React.Component { @@ -417,6 +500,6 @@ it("componentDidMount should be different between components", async () => { expect(events).toEqual(["mountA", "unmountA", "mountB", "unmountB"]) } - await test(true) - await test(false) + await doTest(true) + await doTest(false) }) diff --git a/test/patch.test.js b/test/patch.test.js new file mode 100644 index 00000000..5ed1adf6 --- /dev/null +++ b/test/patch.test.js @@ -0,0 +1,277 @@ +import * as React from "react" +import { createTestRoot, asyncReactDOMRender } from "./" +import { patch } from "../src/utils/utils" + +const testRoot = createTestRoot() + +async function testComponent(C, didMountMixin, willUnmountMixin, doMixinTest = true) { + if (doMixinTest) { + expect(didMountMixin).not.toHaveBeenCalled() + expect(willUnmountMixin).not.toHaveBeenCalled() + } + + await asyncReactDOMRender(, testRoot) + + if (doMixinTest) { + expect(didMountMixin).toHaveBeenCalledTimes(1) + expect(willUnmountMixin).not.toHaveBeenCalled() + } + + await asyncReactDOMRender(null, testRoot) + + if (doMixinTest) { + expect(didMountMixin).toHaveBeenCalledTimes(1) + expect(willUnmountMixin).toHaveBeenCalledTimes(1) + } +} + +test("no overrides", async () => { + const cdm = jest.fn() + const cwu = jest.fn() + class C extends React.Component { + render() { + return null + } + } + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + + await testComponent(C, cdm, cwu) +}) + +test("prototype overrides", async () => { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + class C extends React.Component { + componentDidMount() { + cdmCalls++ + } + componentWillUnmount() { + cwuCalls++ + } + render() { + return null + } + } + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + + await testComponent(C, cdm, cwu) + expect(cdmCalls).toBe(1) + expect(cwuCalls).toBe(1) +}) + +test("arrow function overrides", async () => { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + class C extends React.Component { + componentDidMount = () => { + cdmCalls++ + } + componentWillUnmount = () => { + cwuCalls++ + } + render() { + return null + } + } + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + + await testComponent(C, cdm, cwu) + expect(cdmCalls).toBe(1) + expect(cwuCalls).toBe(1) +}) + +test("recursive calls", async () => { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + class C extends React.Component { + componentDidMount() { + cdmCalls++ + while (cdmCalls < 10) { + this.componentDidMount() + } + } + componentWillUnmount() { + cwuCalls++ + while (cwuCalls < 10) { + this.componentWillUnmount() + } + } + render() { + return null + } + } + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + + await testComponent(C, cdm, cwu) + expect(cdmCalls).toBe(10) + expect(cwuCalls).toBe(10) +}) + +test("prototype + arrow function overrides", async () => { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + class C extends React.Component { + componentDidMount() { + cdmCalls++ + } + componentWillUnmount() { + cwuCalls++ + } + render() { + return null + } + constructor(props) { + super(props) + this.componentDidMount = () => { + cdmCalls++ + } + this.componentWillUnmount = () => { + cwuCalls++ + } + } + } + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + + await testComponent(C, cdm, cwu) + expect(cdmCalls).toBe(1) + expect(cwuCalls).toBe(1) +}) + +describe("inheritance with prototype methods", async () => { + async function doTest(patchBase, patchOther, callSuper) { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + + class B extends React.Component { + componentDidMount() { + cdmCalls++ + } + componentWillUnmount() { + cwuCalls++ + } + } + + class C extends B { + componentDidMount() { + if (callSuper) { + super.componentDidMount() + } + cdmCalls++ + } + componentWillUnmount() { + if (callSuper) { + super.componentWillUnmount() + } + cwuCalls++ + } + render() { + return null + } + } + + if (patchBase) { + patch(B.prototype, "componentDidMount", cdm) + patch(B.prototype, "componentWillUnmount", cwu) + } + if (patchOther) { + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + } + + await testComponent(C, cdm, cwu, patchBase || patchOther) + expect(cdmCalls).toBe(callSuper ? 2 : 1) + expect(cwuCalls).toBe(callSuper ? 2 : 1) + } + + for (const base of [false, true]) { + for (const other of [false, true]) { + for (const callSuper of [false, true]) { + test(`base: ${base}, other: ${other}, callSuper: ${callSuper}`, async () => { + if (base && !other && !callSuper) { + // this one is expected to fail, since we are patching only the base and the other one totally ignores the base method + try { + await doTest(base, other, callSuper) + fail("should have failed") + } catch (e) {} + } else { + await doTest(base, other, callSuper) + } + }) + } + } + } +}) + +describe("inheritance with arrow functions", async () => { + async function doTest(patchBase, patchOther, callSuper) { + const cdm = jest.fn() + const cwu = jest.fn() + let cdmCalls = 0 + let cwuCalls = 0 + + class B extends React.Component { + componentDidMount() { + cdmCalls++ + } + componentWillUnmount() { + cwuCalls++ + } + } + + class C extends B { + componentDidMount = () => { + if (callSuper) { + super.componentDidMount() + } + cdmCalls++ + } + componentWillUnmount = () => { + if (callSuper) { + super.componentWillUnmount() + } + cwuCalls++ + } + render() { + return null + } + } + + if (patchBase) { + patch(B.prototype, "componentDidMount", cdm) + patch(B.prototype, "componentWillUnmount", cwu) + } + if (patchOther) { + patch(C.prototype, "componentDidMount", cdm) + patch(C.prototype, "componentWillUnmount", cwu) + } + + await testComponent(C, cdm, cwu, patchBase || patchOther) + expect(cdmCalls).toBe(callSuper ? 2 : 1) + expect(cwuCalls).toBe(callSuper ? 2 : 1) + } + + for (const base of [false, true]) { + for (const other of [false, true]) { + for (const callSuper of [false, true]) { + test(`base: ${base}, other: ${other}, callSuper: ${callSuper}`, async () => { + await doTest(base, other, callSuper) + }) + } + } + } +})