diff --git a/src/when.js b/src/when.js index 30e9ef7..5dba5b3 100644 --- a/src/when.js +++ b/src/when.js @@ -84,9 +84,10 @@ class WhenMock { this.nextCallMockId++ - this.fn.mockImplementation((...args) => { - for (let i = 0; i < this.callMocks.length; i++) { - const { matchers, mockImplementation, expectCall, once, called } = this.callMocks[i] + const instance = this + this.fn.mockImplementation(function (...args) { + for (let i = 0; i < instance.callMocks.length; i++) { + const { matchers, mockImplementation, expectCall, once, called } = instance.callMocks[i] // Do not let a once mock match more than once if (once && called) continue @@ -110,16 +111,16 @@ class WhenMock { } if (isMatch && typeof mockImplementation === 'function') { - this.callMocks[i].called = true - return mockImplementation(...args) + instance.callMocks[i].called = true + return mockImplementation.call(this, ...args) } } - if (this._defaultImplementation) { - return this._defaultImplementation(...args) + if (instance._defaultImplementation) { + return instance._defaultImplementation.call(this, ...args) } if (typeof fn.__whenMock__._origMock === 'function') { - return fn.__whenMock__._origMock(...args) + return fn.__whenMock__._origMock.call(this, ...args) } return undefined }) diff --git a/src/when.test.js b/src/when.test.js index 1761fd0..7455553 100644 --- a/src/when.test.js +++ b/src/when.test.js @@ -1025,6 +1025,54 @@ describe('When', () => { expect(fn(2)).toEqual('b') }) + it('keeps call context when not matched', () => { + class TheClass { + call () { + return 'ok' + } + + request (...args) { + return this.call(...args) + } + } + + const theInstance = new TheClass() + + const theSpiedMethod = jest.spyOn(theInstance, 'request') + + when(theSpiedMethod) + .calledWith(1) + .mockReturnValue('mock') + + const unhandledCall = theInstance.request() + expect(unhandledCall).toBe('ok') + }) + + it('keeps call context when matched', () => { + class TheClass { + call () { + return 'ok' + } + + request (...args) { + return this.call(...args) + } + } + + const theInstance = new TheClass() + + const theSpiedMethod = jest.spyOn(theInstance, 'request') + + when(theSpiedMethod) + .calledWith(1) + .mockImplementation(function () { + return this.call() + '!' + }) + + const unhandledCall = theInstance.request(1) + expect(unhandledCall).toBe('ok!') + }) + it('does not add to the number of assertion calls', () => { expect.assertions(0) })