From 590c9dbd31c4b5e34394a397716145242340170d Mon Sep 17 00:00:00 2001 From: jiexi Date: Mon, 8 Apr 2024 15:38:02 -0700 Subject: [PATCH] QueuedRequestController: Fix list of methods that should have requests enqueued and/or switch the globally selected network (#4066) ## Explanation Previously we were not properly enqueuing requests that could trigger a confirmation notification as well as not switching networks for methods that either require it or would exhibit unexpected UI/UX without doing so beforehand. This PR improves the specificity of which methods should be handled differently in the queued request flow which covers many methods not previously handled correctly. ## References See: https://github.com/MetaMask/metamask-extension/pull/22865#issuecomment-1996054533 ## Changelog ### `@metamask/queued-request-controller` #### Added - **BREAKING**: The `QueuedRequestMiddleware` constructor now requires the `methodsWithConfirmation` param which should be a list of methods that can trigger confirmations ([#4066](https://github.com/MetaMask/core/pull/4066)) - **BREAKING**: The `QueuedRequestController` constructor now requires the `methodsRequiringNetworkSwitch` param which should be a list of methods that need the globally selected network to switched to the dapp selected network before being processed ([#4066](https://github.com/MetaMask/core/pull/4066)) #### Changed - **BREAKING**: `QueuedRequestController.enqueueRequest()` now ensures the globally selected network matches the dapp selected network before processing methods listed in the `methodsRequiringNetworkSwitch` constructor param. This replaces the previous behavior of switching for all methods except `eth_requestAccounts`. ([#4066](https://github.com/MetaMask/core/pull/4066)) ## Checklist - [x] I've updated the test suite for new or updated code as appropriate - [x] I've updated documentation (JSDoc, Markdown, etc.) for new or updated code as appropriate - [x] I've highlighted breaking changes using the "BREAKING" category above as appropriate --- .../queued-request-controller/CHANGELOG.md | 9 + .../src/QueuedRequestController.test.ts | 183 ++++++++++-------- .../src/QueuedRequestController.ts | 21 +- .../src/QueuedRequestMiddleware.test.ts | 158 +++++++-------- .../src/QueuedRequestMiddleware.ts | 22 +-- 5 files changed, 209 insertions(+), 184 deletions(-) diff --git a/packages/queued-request-controller/CHANGELOG.md b/packages/queued-request-controller/CHANGELOG.md index 39cd85cf00..a450470797 100644 --- a/packages/queued-request-controller/CHANGELOG.md +++ b/packages/queued-request-controller/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **BREAKING**: The `QueuedRequestMiddleware` constructor now requires the `methodsWithConfirmation` param which should be a list of methods that can trigger confirmations ([#4066](https://github.com/MetaMask/core/pull/4066)) +- **BREAKING**: The `QueuedRequestController` constructor now requires the `methodsRequiringNetworkSwitch` param which should be a list of methods that need the globally selected network to switched to the dapp selected network before being processed ([#4066](https://github.com/MetaMask/core/pull/4066)) + +### Changed + +- **BREAKING**: `QueuedRequestController.enqueueRequest()` now ensures the globally selected network matches the dapp selected network before processing methods listed in the `methodsRequiringNetworkSwitch` constructor param. This replaces the previous behavior of switching for all methods except `eth_requestAccounts`. ([#4066](https://github.com/MetaMask/core/pull/4066)) + ## [0.7.0] ### Changed diff --git a/packages/queued-request-controller/src/QueuedRequestController.test.ts b/packages/queued-request-controller/src/QueuedRequestController.test.ts index e00f44311a..80bcc9f8cf 100644 --- a/packages/queued-request-controller/src/QueuedRequestController.test.ts +++ b/packages/queued-request-controller/src/QueuedRequestController.test.ts @@ -25,6 +25,7 @@ describe('QueuedRequestController', () => { it('can be instantiated with default values', () => { const options: QueuedRequestControllerOptions = { messenger: buildQueuedRequestControllerMessenger(), + methodsRequiringNetworkSwitch: [], }; const controller = new QueuedRequestController(options); @@ -33,10 +34,7 @@ describe('QueuedRequestController', () => { describe('enqueueRequest', () => { it('skips the queue if the queue is empty and no request is being processed', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); await controller.enqueueRequest(buildRequest(), async () => { expect(controller.state.queuedRequestCount).toBe(0); @@ -45,10 +43,7 @@ describe('QueuedRequestController', () => { }); it('skips the queue if the queue is empty and the request being processed has the same origin', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Trigger first request const firstRequest = controller.enqueueRequest( buildRequest(), @@ -65,7 +60,7 @@ describe('QueuedRequestController', () => { await firstRequest; }); - it('switches network if a request comes in for a different network client', async () => { + it('switches network if a request comes in for a different network client and the method is in the methodsRequiringNetworkSwitch param', async () => { const mockSetActiveNetwork = jest.fn(); const { messenger } = buildControllerMessenger({ networkControllerGetState: jest.fn().mockReturnValue({ @@ -82,13 +77,13 @@ describe('QueuedRequestController', () => { 'QueuedRequestController:networkSwitched', onNetworkSwitched, ); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: ['method_requiring_network_switch'], + }); await controller.enqueueRequest( - buildRequest(), + { ...buildRequest(), method: 'method_requiring_network_switch' }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); @@ -100,7 +95,7 @@ describe('QueuedRequestController', () => { ); }); - it('does not switch networks if the method is `eth_requestAccounts`', async () => { + it('does not switch networks if the method is not in the methodsRequiringNetworkSwitch param', async () => { const mockSetActiveNetwork = jest.fn(); const { messenger } = buildControllerMessenger({ networkControllerGetState: jest.fn().mockReturnValue({ @@ -117,13 +112,13 @@ describe('QueuedRequestController', () => { 'QueuedRequestController:networkSwitched', onNetworkSwitched, ); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: [], + }); await controller.enqueueRequest( - { ...buildRequest(), method: 'eth_requestAccounts' }, + { ...buildRequest(), method: 'not_in_methodsRequiringNetworkSwitch' }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); @@ -148,10 +143,9 @@ describe('QueuedRequestController', () => { 'QueuedRequestController:networkSwitched', onNetworkSwitched, ); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + }); await controller.enqueueRequest( buildRequest(), @@ -163,10 +157,7 @@ describe('QueuedRequestController', () => { }); it('queues request if a request from another origin is being processed', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Trigger first request const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://exampleorigin1.metamask.io' }, @@ -189,10 +180,7 @@ describe('QueuedRequestController', () => { }); it('drains batch from queue when current batch finishes', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Trigger first batch const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstbatch.metamask.io' }, @@ -236,10 +224,7 @@ describe('QueuedRequestController', () => { }); it('drains batch from queue when current batch finishes with requests out-of-order', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Trigger first batch const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstbatch.metamask.io' }, @@ -283,10 +268,7 @@ describe('QueuedRequestController', () => { }); it('processes requests from each batch in parallel', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, async () => { @@ -342,10 +324,7 @@ describe('QueuedRequestController', () => { }); it('preserves request order within each batch', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); const executionOrder: string[] = []; const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, @@ -398,10 +377,7 @@ describe('QueuedRequestController', () => { }); it('preserves request order even when interlaced with requests from other origins', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); const executionOrder: string[] = []; const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, @@ -461,10 +437,9 @@ describe('QueuedRequestController', () => { 'QueuedRequestController:networkSwitched', onNetworkSwitched, ); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + }); const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, () => new Promise((resolve) => setTimeout(resolve, 10)), @@ -513,10 +488,9 @@ describe('QueuedRequestController', () => { 'QueuedRequestController:networkSwitched', onNetworkSwitched, ); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + }); const firstRequest = controller.enqueueRequest( { ...buildRequest(), origin: 'firstorigin.metamask.io' }, () => new Promise((resolve) => setTimeout(resolve, 10)), @@ -558,14 +532,18 @@ describe('QueuedRequestController', () => { .fn() .mockImplementation((_origin) => 'differentNetworkClientId'), }); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: ['method_requiring_network_switch'], + }); await expect(() => controller.enqueueRequest( - { ...buildRequest(), origin: 'https://example.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://example.metamask.io', + }, jest.fn(), ), ).rejects.toThrow(switchError); @@ -589,12 +567,16 @@ describe('QueuedRequestController', () => { : 'selectedNetworkClientId', ), }); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: ['method_requiring_network_switch'], + }); const firstRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://firstorigin.metamask.io', + }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); // ensure first request skips queue @@ -605,7 +587,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const secondRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://secondorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://secondorigin.metamask.io', + }, secondRequestNext, ); @@ -635,12 +621,16 @@ describe('QueuedRequestController', () => { : 'selectedNetworkClientId', ), }); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: ['method_requiring_network_switch'], + }); const firstRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://firstorigin.metamask.io', + }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); // ensure first request skips queue @@ -651,7 +641,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const secondRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://secondorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://secondorigin.metamask.io', + }, secondRequestNext, ); // ensure test starts with one request queued up @@ -680,12 +674,16 @@ describe('QueuedRequestController', () => { : 'selectedNetworkClientId', ), }); - const options: QueuedRequestControllerOptions = { + const controller = buildQueuedRequestController({ messenger: buildQueuedRequestControllerMessenger(messenger), - }; - const controller = new QueuedRequestController(options); + methodsRequiringNetworkSwitch: ['method_requiring_network_switch'], + }); const firstRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://firstorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://firstorigin.metamask.io', + }, () => new Promise((resolve) => setTimeout(resolve, 10)), ); // ensure first request skips queue @@ -696,7 +694,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const secondRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://secondorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://secondorigin.metamask.io', + }, secondRequestNext, ); const thirdRequestNext = jest @@ -705,7 +707,11 @@ describe('QueuedRequestController', () => { () => new Promise((resolve) => setTimeout(resolve, 100)), ); const thirdRequest = controller.enqueueRequest( - { ...buildRequest(), origin: 'https://thirdorigin.metamask.io' }, + { + ...buildRequest(), + method: 'method_requiring_network_switch', + origin: 'https://thirdorigin.metamask.io', + }, thirdRequestNext, ); // ensure test starts with two requests queued up @@ -722,11 +728,7 @@ describe('QueuedRequestController', () => { describe('when a request fails', () => { it('throws error', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Mock a request that throws an error const requestWithError = jest.fn(() => @@ -744,10 +746,7 @@ describe('QueuedRequestController', () => { }); it('correctly updates the request queue count upon failure', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); await expect(() => controller.enqueueRequest( @@ -761,11 +760,7 @@ describe('QueuedRequestController', () => { }); it('correctly processes the next item in the queue', async () => { - const options: QueuedRequestControllerOptions = { - messenger: buildQueuedRequestControllerMessenger(), - }; - - const controller = new QueuedRequestController(options); + const controller = buildQueuedRequestController(); // Mock requests with one request throwing an error const request1 = jest.fn(async () => { @@ -900,6 +895,24 @@ function buildQueuedRequestControllerMessenger( }); } +/** + * Builds a QueuedRequestController + * + * @param overrideOptions - The optional options object. + * @returns The QueuedRequestController. + */ +function buildQueuedRequestController( + overrideOptions?: Partial, +): QueuedRequestController { + const options: QueuedRequestControllerOptions = { + messenger: buildQueuedRequestControllerMessenger(), + methodsRequiringNetworkSwitch: [], + ...overrideOptions, + }; + + return new QueuedRequestController(options); +} + /** * Build a valid JSON-RPC request that includes all required properties * diff --git a/packages/queued-request-controller/src/QueuedRequestController.ts b/packages/queued-request-controller/src/QueuedRequestController.ts index 63c591be4b..a371c13ac3 100644 --- a/packages/queued-request-controller/src/QueuedRequestController.ts +++ b/packages/queued-request-controller/src/QueuedRequestController.ts @@ -73,6 +73,7 @@ export type QueuedRequestControllerMessenger = RestrictedControllerMessenger< export type QueuedRequestControllerOptions = { messenger: QueuedRequestControllerMessenger; + methodsRequiringNetworkSwitch: string[]; }; /** @@ -127,13 +128,27 @@ export class QueuedRequestController extends BaseController< */ #processingRequestCount = 0; + /** + * This is a list of methods that require the globally selected network + * to match the dapp selected network before being processed. These can + * be for UI/UX reasons where the currently selected network is displayed + * in the confirmation even though it will be submitted on the correct + * network for the dapp. It could also be that a method expects the + * globally selected network to match some value in the request params itself. + */ + readonly #methodsRequiringNetworkSwitch: string[]; + /** * Construct a QueuedRequestController. * * @param options - Controller options. * @param options.messenger - The restricted controller messenger that facilitates communication with other controllers. + * @param options.methodsRequiringNetworkSwitch - A list of methods that require the globally selected network to match the dapp selected network. */ - constructor({ messenger }: QueuedRequestControllerOptions) { + constructor({ + messenger, + methodsRequiringNetworkSwitch, + }: QueuedRequestControllerOptions) { super({ name: controllerName, metadata: { @@ -145,6 +160,7 @@ export class QueuedRequestController extends BaseController< messenger, state: { queuedRequestCount: 0 }, }); + this.#methodsRequiringNetworkSwitch = methodsRequiringNetworkSwitch; this.#registerMessageHandlers(); } @@ -282,10 +298,9 @@ export class QueuedRequestController extends BaseController< this.#updateQueuedRequestCount(); await waitForDequeue; - } else if (request.method !== 'eth_requestAccounts') { + } else if (this.#methodsRequiringNetworkSwitch.includes(request.method)) { // Process request immediately // Requires switching network now if necessary - // Note: we dont need to switch chain before processing eth_requestAccounts because accounts are not network-specific (at the time of writing) await this.#switchNetworkIfNecessary(); } this.#processingRequestCount += 1; diff --git a/packages/queued-request-controller/src/QueuedRequestMiddleware.test.ts b/packages/queued-request-controller/src/QueuedRequestMiddleware.test.ts index f8b1b5dead..c062c70484 100644 --- a/packages/queued-request-controller/src/QueuedRequestMiddleware.test.ts +++ b/packages/queued-request-controller/src/QueuedRequestMiddleware.test.ts @@ -5,37 +5,9 @@ import type { QueuedRequestControllerEnqueueRequestAction } from './QueuedReques import { createQueuedRequestMiddleware } from './QueuedRequestMiddleware'; import type { QueuedRequestMiddlewareJsonRpcRequest } from './types'; -const getRequestDefaults = (): QueuedRequestMiddlewareJsonRpcRequest => { - return { - method: 'doesnt matter', - id: 'doesnt matter', - jsonrpc: '2.0' as const, - origin: 'example.com', - networkClientId: 'mainnet', - }; -}; - -const getPendingResponseDefault = (): PendingJsonRpcResponse => { - return { - id: 'doesnt matter', - jsonrpc: '2.0' as const, - }; -}; - -const getMockEnqueueRequest = () => - jest - .fn< - ReturnType, - Parameters - >() - .mockImplementation((_request, requestNext) => requestNext()); - describe('createQueuedRequestMiddleware', () => { it('throws if not provided an origin', async () => { - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: getMockEnqueueRequest(), - useRequestQueue: () => false, - }); + const middleware = buildQueuedRequestMiddleware(); const request = getRequestDefaults(); // @ts-expect-error Intentionally invalid request delete request.origin; @@ -49,10 +21,7 @@ describe('createQueuedRequestMiddleware', () => { }); it('throws if provided an invalid origin', async () => { - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: getMockEnqueueRequest(), - useRequestQueue: () => false, - }); + const middleware = buildQueuedRequestMiddleware(); const request = getRequestDefaults(); // @ts-expect-error Intentionally invalid request request.origin = 1; @@ -66,10 +35,7 @@ describe('createQueuedRequestMiddleware', () => { }); it('throws if not provided an networkClientId', async () => { - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: getMockEnqueueRequest(), - useRequestQueue: () => false, - }); + const middleware = buildQueuedRequestMiddleware(); const request = getRequestDefaults(); // @ts-expect-error Intentionally invalid request delete request.networkClientId; @@ -83,10 +49,7 @@ describe('createQueuedRequestMiddleware', () => { }); it('throws if provided an invalid networkClientId', async () => { - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: getMockEnqueueRequest(), - useRequestQueue: () => false, - }); + const middleware = buildQueuedRequestMiddleware(); const request = getRequestDefaults(); // @ts-expect-error Intentionally invalid request request.networkClientId = 1; @@ -103,9 +66,8 @@ describe('createQueuedRequestMiddleware', () => { it('does not enqueue the request when useRequestQueue is false', async () => { const mockEnqueueRequest = getMockEnqueueRequest(); - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: mockEnqueueRequest, - useRequestQueue: () => false, }); await new Promise((resolve, reject) => @@ -122,7 +84,7 @@ describe('createQueuedRequestMiddleware', () => { it('does not enqueue request that has no confirmation', async () => { const mockEnqueueRequest = getMockEnqueueRequest(); - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: mockEnqueueRequest, useRequestQueue: () => true, }); @@ -139,38 +101,17 @@ describe('createQueuedRequestMiddleware', () => { expect(mockEnqueueRequest).not.toHaveBeenCalled(); }); - it('enqueues request that has a confirmation', async () => { + it('enqueues the request if the method is in the methodsWithConfirmation param', async () => { const mockEnqueueRequest = getMockEnqueueRequest(); - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: mockEnqueueRequest, useRequestQueue: () => true, + methodsWithConfirmation: ['method_with_confirmation'], }); const request = { ...getRequestDefaults(), origin: 'exampleorigin.com', - method: 'eth_sendTransaction', - }; - - await new Promise((resolve, reject) => - middleware(request, getPendingResponseDefault(), resolve, reject), - ); - - expect(mockEnqueueRequest).toHaveBeenCalledWith( - request, - expect.any(Function), - ); - }); - - it('enqueues request that have a confirmation', async () => { - const mockEnqueueRequest = getMockEnqueueRequest(); - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: mockEnqueueRequest, - useRequestQueue: () => true, - }); - const request = { - ...getRequestDefaults(), - origin: 'exampleorigin.com', - method: 'eth_sendTransaction', + method: 'method_with_confirmation', }; await new Promise((resolve, reject) => @@ -184,10 +125,7 @@ describe('createQueuedRequestMiddleware', () => { }); it('calls next when a request is not queued', async () => { - const middleware = createQueuedRequestMiddleware({ - enqueueRequest: getMockEnqueueRequest(), - useRequestQueue: () => false, - }); + const middleware = buildQueuedRequestMiddleware(); const mockNext = jest.fn(); await new Promise((resolve) => { @@ -204,7 +142,7 @@ describe('createQueuedRequestMiddleware', () => { }); it('calls next after a request is queued and processed', async () => { - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: getMockEnqueueRequest(), useRequestQueue: () => true, }); @@ -224,15 +162,16 @@ describe('createQueuedRequestMiddleware', () => { describe('when enqueueRequest throws', () => { it('ends without calling next', async () => { - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: jest .fn() .mockRejectedValue(new Error('enqueuing error')), useRequestQueue: () => true, + methodsWithConfirmation: ['method_should_be_enqueued'], }); const request = { ...getRequestDefaults(), - method: 'eth_sendTransaction', + method: 'method_should_be_enqueued', }; const mockNext = jest.fn(); const mockEnd = jest.fn(); @@ -247,15 +186,16 @@ describe('createQueuedRequestMiddleware', () => { }); it('serializes processing errors and attaches them to the response', async () => { - const middleware = createQueuedRequestMiddleware({ + const middleware = buildQueuedRequestMiddleware({ enqueueRequest: jest .fn() .mockRejectedValue(new Error('enqueuing error')), useRequestQueue: () => true, + methodsWithConfirmation: ['method_should_be_enqueued'], }); const request = { ...getRequestDefaults(), - method: 'eth_sendTransaction', + method: 'method_should_be_enqueued', }; const response = getPendingResponseDefault(); @@ -275,3 +215,65 @@ describe('createQueuedRequestMiddleware', () => { }); }); }); + +/** + * Build a valid JSON-RPC request that includes all required properties + * + * @returns A valid JSON-RPC request with all required properties. + */ +function getRequestDefaults(): QueuedRequestMiddlewareJsonRpcRequest { + return { + method: 'doesnt matter', + id: 'doesnt matter', + jsonrpc: '2.0' as const, + origin: 'example.com', + networkClientId: 'mainnet', + }; +} + +/** + * Build a partial JSON-RPC response + * + * @returns A partial response request + */ +function getPendingResponseDefault(): PendingJsonRpcResponse { + return { + id: 'doesnt matter', + jsonrpc: '2.0' as const, + }; +} + +/** + * Builds a mock QueuedRequestController.enqueueRequest function + * + * @returns A mock function that calls the next request in the middleware chain + */ +function getMockEnqueueRequest() { + return jest + .fn< + ReturnType, + Parameters + >() + .mockImplementation((_request, requestNext) => requestNext()); +} + +/** + * Builds the QueuedRequestMiddleware + * + * @param overrideOptions - The optional options object. + * @returns The QueuedRequestMiddleware. + */ +function buildQueuedRequestMiddleware( + overrideOptions?: Partial< + Parameters[0] + >, +) { + const options = { + enqueueRequest: getMockEnqueueRequest(), + useRequestQueue: () => false, + methodsWithConfirmation: [], + ...overrideOptions, + }; + + return createQueuedRequestMiddleware(options); +} diff --git a/packages/queued-request-controller/src/QueuedRequestMiddleware.ts b/packages/queued-request-controller/src/QueuedRequestMiddleware.ts index e0a1f988fa..fbafa384bf 100644 --- a/packages/queued-request-controller/src/QueuedRequestMiddleware.ts +++ b/packages/queued-request-controller/src/QueuedRequestMiddleware.ts @@ -6,23 +6,6 @@ import type { Json, JsonRpcParams, JsonRpcRequest } from '@metamask/utils'; import type { QueuedRequestController } from './QueuedRequestController'; import type { QueuedRequestMiddlewareJsonRpcRequest } from './types'; -const isConfirmationMethod = (method: string) => { - const confirmationMethods = [ - 'eth_sendTransaction', - 'wallet_watchAsset', - 'wallet_switchEthereumChain', - 'eth_signTypedData_v4', - 'wallet_addEthereumChain', - 'wallet_requestPermissions', - 'wallet_requestSnaps', - 'personal_sign', - 'eth_sign', - 'eth_requestAccounts', - ]; - - return confirmationMethods.includes(method); -}; - /** * Ensure that the incoming request has the additional required request metadata. This metadata * should be attached to the request earlier in the middleware pipeline. @@ -56,21 +39,24 @@ function hasRequiredMetadata( * @param options - Configuration options. * @param options.enqueueRequest - A method for enqueueing a request. * @param options.useRequestQueue - A function that determines if the request queue feature is enabled. + * @param options.methodsWithConfirmation - A list of methods that can cause a confirmation to be presented to the user. * @returns The JSON-RPC middleware that manages queued requests. */ export const createQueuedRequestMiddleware = ({ enqueueRequest, useRequestQueue, + methodsWithConfirmation, }: { enqueueRequest: QueuedRequestController['enqueueRequest']; useRequestQueue: () => boolean; + methodsWithConfirmation: string[]; }): JsonRpcMiddleware => { return createAsyncMiddleware(async (req: JsonRpcRequest, res, next) => { hasRequiredMetadata(req); // if the request queue feature is turned off, or this method is not a confirmation method // bypass the queue completely - if (!useRequestQueue() || !isConfirmationMethod(req.method)) { + if (!useRequestQueue() || !methodsWithConfirmation.includes(req.method)) { return await next(); }