From b96f94b9e67577c123c4ff889efe055b7641bc60 Mon Sep 17 00:00:00 2001 From: jiexi Date: Fri, 22 Sep 2023 11:46:18 -0700 Subject: [PATCH] TokensController.addToken use networkClientId (#1676) Currently, TokenController only uses the global selected provider. We want to see how it feels making this controller consume any network provider the NetworkController may have available. We do this by extending the `addToken()` interface with an optional `networkClientId` param, but continuing to use the provider proxy as a fallback. ~~TODO: specs have not been updated to test the networkClientId case~~ added * Fixes https://github.com/MetaMask/MetaMask-planning/issues/1020 * See https://github.com/MetaMask/metamask-extension/pull/20916 ## Explanation ## References ## Changelog ### `@metamask/assets-controllers` - **BREAKING**: `TokensController` now expects `getNetworkClientById` in constructor options - **BREAKING**: `TokensController. addToken()` now accepts a single options object ``` { address: string; symbol: string; decimals: number; name?: string; image?: string; interactingAddress?: string; networkClientId?: NetworkClientId; } ``` - **CHANGED**: `TokensController. addToken()` will use the chain ID value derived from state for `networkClientId` if provided - **CHANGED**: `TokensController. addTokens()` now accepts an optional `networkClientId` as the last parameter - **CHANGED**: `TokensController. addTokens()` will use the chain ID value derived from state for `networkClientId` if provided - **CHANGED**: `TokensController. watchAsset()` options now accepts optional `networkClientId` which is used to get the ERC-20 token name if provided ## Checklist - [x] I've updated the test suite for new or updated code as appropriate - [x] I've updated documentation (JSDoc, Markdown, etc.) for new or updated code as appropriate - [x] I've highlighted breaking changes using the "BREAKING" category above as appropriate --- .../src/TokenBalancesController.test.ts | 9 +- .../src/TokenDetectionController.test.ts | 37 +- .../src/TokensController.test.ts | 397 ++++++++++++++---- .../src/TokensController.ts | 136 ++++-- 4 files changed, 443 insertions(+), 136 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 564e4209b9..2be2e5c0fc 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -146,6 +146,7 @@ describe('TokenBalancesController', () => { messenger.subscribe('NetworkController:stateChange', listener), onTokenListStateChange: sinon.stub(), getERC20TokenName: sinon.stub(), + getNetworkClientById: sinon.stub() as any, messenger: undefined as unknown as TokensControllerMessenger, }); const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; @@ -185,6 +186,7 @@ describe('TokenBalancesController', () => { messenger.subscribe('NetworkController:stateChange', listener), onTokenListStateChange: sinon.stub(), getERC20TokenName: sinon.stub(), + getNetworkClientById: sinon.stub() as any, messenger: undefined as unknown as TokensControllerMessenger, }); const errorMsg = 'Failed to get balance'; @@ -241,6 +243,7 @@ describe('TokenBalancesController', () => { messenger.subscribe('NetworkController:stateChange', listener), onTokenListStateChange: sinon.stub(), getERC20TokenName: sinon.stub(), + getNetworkClientById: sinon.stub() as any, messenger: undefined as unknown as TokensControllerMessenger, }); @@ -256,7 +259,11 @@ describe('TokenBalancesController', () => { { interval: 1337 }, ); const updateBalances = sinon.stub(tokenBalances, 'updateBalances'); - await tokensController.addToken('0x00', 'FOO', 18); + await tokensController.addToken({ + address: '0x00', + symbol: 'FOO', + decimals: 18, + }); const { tokens } = tokensController.state; const found = tokens.filter((token: Token) => token.address === '0x00'); expect(found.length > 0).toBe(true); diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 60792bc36e..36c0e4176a 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -170,8 +170,8 @@ describe('TokenDetectionController', () => { preferences = new PreferencesController({}, { useTokenDetection: true }); controllerMessenger = getControllerMessenger(); sinon - .stub(TokensController.prototype, '_instantiateNewEthersProvider') - .callsFake(() => null); + .stub(TokensController.prototype, '_createEthersContract') + .callsFake(() => null as any); tokensController = new TokensController({ chainId: ChainId.mainnet, @@ -180,6 +180,7 @@ describe('TokenDetectionController', () => { onNetworkStateChangeListeners.push(listener), onTokenListStateChange: sinon.stub(), getERC20TokenName: sinon.stub(), + getNetworkClientById: sinon.stub() as any, messenger: undefined as unknown as TokensControllerMessenger, }); @@ -336,18 +337,18 @@ describe('TokenDetectionController', () => { await tokenDetection.start(); - await tokensController.addToken( - sampleTokenA.address, - sampleTokenA.symbol, - sampleTokenA.decimals, - ); + await tokensController.addToken({ + address: sampleTokenA.address, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + }); - await tokensController.addToken( - sampleTokenB.address, - sampleTokenB.symbol, - sampleTokenB.decimals, - { name: sampleTokenB.name }, - ); + await tokensController.addToken({ + address: sampleTokenB.address, + symbol: sampleTokenB.symbol, + decimals: sampleTokenB.decimals, + name: sampleTokenB.name, + }); tokensController.ignoreTokens([sampleTokenA.address]); @@ -368,11 +369,11 @@ describe('TokenDetectionController', () => { await tokenDetection.start(); - await tokensController.addToken( - sampleTokenA.address, - sampleTokenA.symbol, - sampleTokenA.decimals, - ); + await tokensController.addToken({ + address: sampleTokenA.address, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + }); tokensController.ignoreTokens([sampleTokenA.address]); diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 8faab27b2a..99db8d64f8 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -92,12 +92,9 @@ describe('TokensController', () => { selectedAddress: defaultSelectedAddress, }, getERC20TokenName: sinon.stub(), + getNetworkClientById: sinon.stub() as any, messenger, }); - - sinon - .stub(tokensController, '_instantiateNewEthersProvider') - .callsFake(() => null); }); afterEach(() => { @@ -117,7 +114,11 @@ describe('TokensController', () => { it('should add a token', async () => { const stub = stubCreateEthers(tokensController, false); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); expect(tokensController.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, @@ -128,7 +129,11 @@ describe('TokensController', () => { aggregators: [], name: undefined, }); - await tokensController.addToken('0x01', 'baz', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'baz', + decimals: 2, + }); expect(tokensController.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, @@ -293,7 +298,11 @@ describe('TokensController', () => { const secondAddress = '0x321'; preferences.update({ selectedAddress: firstAddress }); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); preferences.update({ selectedAddress: secondAddress }); expect(tokensController.state.tokens).toHaveLength(0); preferences.update({ selectedAddress: firstAddress }); @@ -314,7 +323,11 @@ describe('TokensController', () => { it('should add token by network', async () => { const stub = stubCreateEthers(tokensController, false); changeNetwork(SEPOLIA); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); changeNetwork(GOERLI); expect(tokensController.state.tokens).toHaveLength(0); @@ -334,9 +347,51 @@ describe('TokensController', () => { stub.restore(); }); + it('should add token to the correct chainId when passed a networkClientId', async () => { + const stub = stubCreateEthers(tokensController, false); + const getNetworkClientByIdStub = jest + .spyOn(tokensController as any, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x5' } }); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + networkClientId: 'networkClientId1', + }); + expect(tokensController.state.tokens[0]).toStrictEqual({ + address: '0x01', + decimals: 2, + image: + 'https://static.metafi.codefi.network/api/v1/tokenIcons/5/0x01.png', + symbol: 'bar', + isERC721: false, + aggregators: [], + name: undefined, + }); + expect(tokensController.state.allTokens['0x5']['0x1']).toStrictEqual([ + { + address: '0x01', + decimals: 2, + image: + 'https://static.metafi.codefi.network/api/v1/tokenIcons/5/0x01.png', + symbol: 'bar', + isERC721: false, + aggregators: [], + name: undefined, + }, + ]); + + expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); + stub.restore(); + }); + it('should remove token', async () => { const stub = stubCreateEthers(tokensController, false); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(0); stub.restore(); @@ -347,9 +402,17 @@ describe('TokensController', () => { const firstAddress = '0x123'; const secondAddress = '0x321'; preferences.update({ selectedAddress: firstAddress }); - await tokensController.addToken('0x02', 'baz', 2); + await tokensController.addToken({ + address: '0x02', + symbol: 'baz', + decimals: 2, + }); preferences.update({ selectedAddress: secondAddress }); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(0); preferences.update({ selectedAddress: firstAddress }); @@ -369,9 +432,17 @@ describe('TokensController', () => { it('should remove token by provider type', async () => { const stub = stubCreateEthers(tokensController, false); changeNetwork(SEPOLIA); - await tokensController.addToken('0x02', 'baz', 2); + await tokensController.addToken({ + address: '0x02', + symbol: 'baz', + decimals: 2, + }); changeNetwork(GOERLI); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(0); changeNetwork(SEPOLIA); @@ -412,14 +483,26 @@ describe('TokensController', () => { }); it('should remove token from ignoredTokens/allIgnoredTokens lists if added back via addToken', async () => { - await tokensController.addToken('0x01', 'foo', 2); - await tokensController.addToken('0xFAa', 'bar', 3); + await tokensController.addToken({ + address: '0x01', + symbol: 'foo', + decimals: 2, + }); + await tokensController.addToken({ + address: '0xFAa', + symbol: 'bar', + decimals: 3, + }); expect(tokensController.state.ignoredTokens).toHaveLength(0); expect(tokensController.state.tokens).toHaveLength(2); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(1); expect(tokensController.state.ignoredTokens).toHaveLength(1); - await tokensController.addToken('0x01', 'baz', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'baz', + decimals: 2, + }); expect(tokensController.state.tokens).toHaveLength(2); expect(tokensController.state.ignoredTokens).toHaveLength(0); }); @@ -428,8 +511,16 @@ describe('TokensController', () => { const selectedAddress = '0x0001'; preferences.setSelectedAddress(selectedAddress); changeNetwork(SEPOLIA); - await tokensController.addToken('0x01', 'bar', 2); - await tokensController.addToken('0xFAa', 'bar', 3); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); + await tokensController.addToken({ + address: '0xFAa', + symbol: 'bar', + decimals: 3, + }); expect(tokensController.state.ignoredTokens).toHaveLength(0); expect(tokensController.state.tokens).toHaveLength(2); tokensController.ignoreTokens(['0x01']); @@ -451,7 +542,11 @@ describe('TokensController', () => { }); it('should be able to clear the ignoredToken list', async () => { - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); expect(tokensController.state.ignoredTokens).toHaveLength(0); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(0); @@ -474,7 +569,11 @@ describe('TokensController', () => { preferences.setSelectedAddress(selectedAddress1); changeNetwork(SEPOLIA); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); expect(tokensController.state.ignoredTokens).toHaveLength(0); tokensController.ignoreTokens(['0x01']); expect(tokensController.state.tokens).toHaveLength(0); @@ -483,13 +582,21 @@ describe('TokensController', () => { changeNetwork(GOERLI); expect(tokensController.state.ignoredTokens).toHaveLength(0); - await tokensController.addToken('0x02', 'bazz', 3); + await tokensController.addToken({ + address: '0x02', + symbol: 'bazz', + decimals: 3, + }); tokensController.ignoreTokens(['0x02']); expect(tokensController.state.ignoredTokens).toStrictEqual(['0x02']); preferences.setSelectedAddress(selectedAddress2); expect(tokensController.state.ignoredTokens).toHaveLength(0); - await tokensController.addToken('0x03', 'foo', 4); + await tokensController.addToken({ + address: '0x03', + symbol: 'foo', + decimals: 4, + }); tokensController.ignoreTokens(['0x03']); expect(tokensController.state.ignoredTokens).toStrictEqual(['0x03']); @@ -507,8 +614,16 @@ describe('TokensController', () => { it('should ignore multiple tokens with single ignoreTokens call', async () => { const stub = stubCreateEthers(tokensController, false); - await tokensController.addToken('0x01', 'A', 4); - await tokensController.addToken('0x02', 'B', 5); + await tokensController.addToken({ + address: '0x01', + symbol: 'A', + decimals: 4, + }); + await tokensController.addToken({ + address: '0x02', + symbol: 'B', + decimals: 5, + }); expect(tokensController.state.tokens).toStrictEqual([ { address: '0x01', @@ -614,7 +729,7 @@ describe('TokensController', () => { ); const address = erc721ContractAddresses[0]; const { symbol, decimals } = contractMaps[address]; - await tokensController.addToken(address, symbol, decimals); + await tokensController.addToken({ address, symbol, decimals }); expect(tokensController.state.tokens).toStrictEqual([ expect.objectContaining({ @@ -630,7 +745,11 @@ describe('TokensController', () => { const stub = stubCreateEthers(tokensController, true); const tokenAddress = '0xDA5584Cc586d07c7141aA427224A4Bd58E64aF7D'; - await tokensController.addToken(tokenAddress, 'REST', 4); + await tokensController.addToken({ + address: tokenAddress, + symbol: 'REST', + decimals: 4, + }); expect(tokensController.state.tokens).toStrictEqual([ { @@ -656,7 +775,7 @@ describe('TokensController', () => { const address = erc20ContractAddresses[0]; const { symbol, decimals } = contractMaps[address]; - await tokensController.addToken(address, symbol, decimals); + await tokensController.addToken({ address, symbol, decimals }); expect(tokensController.state.tokens).toStrictEqual([ expect.objectContaining({ @@ -672,7 +791,11 @@ describe('TokensController', () => { const stub = stubCreateEthers(tokensController, false); const tokenAddress = '0xDA5584Cc586d07c7141aA427224A4Bd58E64aF7D'; - await tokensController.addToken(tokenAddress, 'LEST', 5); + await tokensController.addToken({ + address: tokenAddress, + symbol: 'LEST', + decimals: 5, + }); expect(tokensController.state.tokens).toStrictEqual([ { @@ -692,11 +815,11 @@ describe('TokensController', () => { it('should throw error if switching networks while adding token', async function () { const dummyTokenAddress = '0x514910771AF9Ca656af840dff83E8264EcF986CA'; - const addTokenPromise = tokensController.addToken( - dummyTokenAddress, - 'LINK', - 18, - ); + const addTokenPromise = tokensController.addToken({ + address: dummyTokenAddress, + symbol: 'LINK', + decimals: 18, + }); changeNetwork(GOERLI); await expect(addTokenPromise).rejects.toThrow( 'TokensController Error: Switched networks while adding token', @@ -718,7 +841,11 @@ describe('TokensController', () => { .persist(); await expect( - tokensController.addToken(dummyTokenAddress, 'LINK', 18), + tokensController.addToken({ + address: dummyTokenAddress, + symbol: 'LINK', + decimals: 18, + }), ).rejects.toThrow(fullErrorMessage); }); @@ -745,11 +872,11 @@ describe('TokensController', () => { dummyDetectedToken, ]); - await tokensController.addToken( - dummyDetectedToken.address, - dummyDetectedToken.symbol, - dummyDetectedToken.decimals, - ); + await tokensController.addToken({ + address: dummyDetectedToken.address, + symbol: dummyDetectedToken.symbol, + decimals: dummyDetectedToken.decimals, + }); expect(tokensController.state.detectedTokens).toStrictEqual([]); expect(tokensController.state.tokens).toStrictEqual([dummyAddedToken]); @@ -796,12 +923,12 @@ describe('TokensController', () => { }); // will add token to currently configured chainId/selectedAddress - await tokensController.addToken( - directlyAddedToken.address, - directlyAddedToken.symbol, - directlyAddedToken.decimals, - { image: directlyAddedToken.image }, - ); + await tokensController.addToken({ + address: directlyAddedToken.address, + symbol: directlyAddedToken.symbol, + decimals: directlyAddedToken.decimals, + image: directlyAddedToken.image, + }); expect(tokensController.state.allDetectedTokens).toStrictEqual({ [DETECTED_CHAINID]: { @@ -860,6 +987,39 @@ describe('TokensController', () => { expect(tokensController.state.detectedTokens).toStrictEqual([]); expect(tokensController.state.tokens).toStrictEqual(dummyAddedTokens); }); + + it('should add tokens to the correct chainId when passed a networkClientId', async () => { + const getNetworkClientByIdStub = jest + .spyOn(tokensController as any, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x5' } }); + + const dummyTokens: Token[] = [ + { + address: '0x01', + symbol: 'barA', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + { + address: '0x02', + symbol: 'barB', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + ]; + + await tokensController.addTokens(dummyTokens, 'networkClientId1'); + + expect(tokensController.state.tokens).toStrictEqual(dummyTokens); + expect(tokensController.state.allTokens['0x5']['0x1']).toStrictEqual( + dummyTokens, + ); + expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); + }); }); describe('_getNewAllTokensState method', () => { @@ -919,7 +1079,7 @@ describe('TokensController', () => { }); }); - describe('on watchAsset', function () { + describe('watchAsset', function () { let asset: any, type: any; const interactingAddress = '0x2'; const requestId = '12345'; @@ -943,7 +1103,7 @@ describe('TokensController', () => { it('should error if passed no type', async function () { type = undefined; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Asset of type undefined not supported', ); @@ -951,7 +1111,7 @@ describe('TokensController', () => { it('should error if asset type is not supported', async function () { type = 'ERC721'; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Asset of type ERC721 not supported', ); @@ -959,7 +1119,7 @@ describe('TokensController', () => { it('should error if address is not defined', async function () { asset.address = undefined; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Must specify address, symbol, and decimals.', ); @@ -967,7 +1127,7 @@ describe('TokensController', () => { it('should error if decimals is not defined', async function () { asset.decimals = undefined; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Must specify address, symbol, and decimals.', ); @@ -975,7 +1135,7 @@ describe('TokensController', () => { it('should error if symbol is not defined', async function () { asset.symbol = undefined; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Must specify address, symbol, and decimals.', ); @@ -983,7 +1143,7 @@ describe('TokensController', () => { it('should error if symbol is empty', async function () { asset.symbol = ''; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Must specify address, symbol, and decimals.', ); @@ -991,7 +1151,7 @@ describe('TokensController', () => { it('should error if symbol is too long', async function () { asset.symbol = 'ABCDEFGHIJKLM'; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Invalid symbol "ABCDEFGHIJKLM": longer than 11 characters.', ); @@ -999,13 +1159,13 @@ describe('TokensController', () => { it('should error if decimals is invalid', async function () { asset.decimals = -1; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow( 'Invalid decimals "-1": must be 0 <= 36.', ); asset.decimals = 37; - const result2 = tokensController.watchAsset(asset, type); + const result2 = tokensController.watchAsset({ asset, type }); await expect(result2).rejects.toThrow( 'Invalid decimals "37": must be 0 <= 36.', ); @@ -1013,20 +1173,20 @@ describe('TokensController', () => { it('should error if address is invalid', async function () { asset.address = '0x123'; - const result = tokensController.watchAsset(asset, type); + const result = tokensController.watchAsset({ asset, type }); await expect(result).rejects.toThrow('Invalid address "0x123".'); }); it('fails with an invalid type suggested', async () => { await expect( - tokensController.watchAsset( - { + tokensController.watchAsset({ + asset: { address: '0xe9f786dfdd9ae4d57e830acb52296837765f0e5b', decimals: 18, symbol: 'TKN', }, - 'ERC721', - ), + type: 'ERC721', + }), ).rejects.toThrow('Asset of type ERC721 not supported'); }); @@ -1039,7 +1199,7 @@ describe('TokensController', () => { .spyOn(messenger, 'call') .mockResolvedValue(undefined); - await tokensController.watchAsset(asset, type); + await tokensController.watchAsset({ asset, type }); expect(tokensController.state.tokens).toHaveLength(1); expect(tokensController.state.tokens).toStrictEqual([ @@ -1077,7 +1237,7 @@ describe('TokensController', () => { .spyOn(messenger, 'call') .mockResolvedValue(undefined); - await tokensController.watchAsset(asset, type, interactingAddress); + await tokensController.watchAsset({ asset, type, interactingAddress }); expect(tokensController.state.tokens).toHaveLength(0); expect(tokensController.state.tokens).toStrictEqual([]); @@ -1112,6 +1272,65 @@ describe('TokensController', () => { generateRandomIdStub.mockRestore(); }); + it('stores token correctly when passed a networkClientId', async function () { + const getNetworkClientByIdStub = jest + .spyOn(tokensController as any, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x5' } }); + const getERC20TokenNameStub = jest + .spyOn(tokensController as any, 'getERC20TokenName') + .mockReturnValue(undefined); + const generateRandomIdStub = jest + .spyOn(tokensController, '_generateRandomId') + .mockReturnValue(requestId); + + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockResolvedValue(undefined); + + await tokensController.watchAsset({ + asset, + type, + interactingAddress, + networkClientId: 'networkClientId1', + }); + + expect(tokensController.state.tokens).toHaveLength(0); + expect(tokensController.state.tokens).toStrictEqual([]); + expect( + tokensController.state.allTokens['0x5'][interactingAddress], + ).toHaveLength(1); + expect( + tokensController.state.allTokens['0x5'][interactingAddress], + ).toStrictEqual([ + { + isERC721: false, + aggregators: [], + ...asset, + }, + ]); + expect(callActionSpy).toHaveBeenCalledTimes(1); + expect(callActionSpy).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: requestId, + origin: ORIGIN_METAMASK, + type: ApprovalType.WatchAsset, + requestData: { + id: requestId, + interactingAddress, + asset, + }, + }, + true, + ); + expect(getERC20TokenNameStub).toHaveBeenCalledWith( + asset.address, + 'networkClientId1', + ); + expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); + generateRandomIdStub.mockRestore(); + }); + it('throws and token is not added if pending approval fails', async function () { const generateRandomIdStub = jest .spyOn(tokensController, '_generateRandomId') @@ -1122,9 +1341,9 @@ describe('TokensController', () => { .spyOn(messenger, 'call') .mockRejectedValue(new Error(errorMessage)); - await expect(tokensController.watchAsset(asset, type)).rejects.toThrow( - errorMessage, - ); + await expect( + tokensController.watchAsset({ asset, type }), + ).rejects.toThrow(errorMessage); expect(tokensController.state.tokens).toHaveLength(0); expect(tokensController.state.tokens).toStrictEqual([]); @@ -1152,11 +1371,23 @@ describe('TokensController', () => { it('should update tokens list when set address changes', async function () { const stub = stubCreateEthers(tokensController, false); preferences.setSelectedAddress('0x1'); - await tokensController.addToken('0x01', 'A', 4); - await tokensController.addToken('0x02', 'B', 5); + await tokensController.addToken({ + address: '0x01', + symbol: 'A', + decimals: 4, + }); + await tokensController.addToken({ + address: '0x02', + symbol: 'B', + decimals: 5, + }); preferences.setSelectedAddress('0x2'); expect(tokensController.state.tokens).toStrictEqual([]); - await tokensController.addToken('0x03', 'C', 6); + await tokensController.addToken({ + address: '0x03', + symbol: 'C', + decimals: 6, + }); preferences.setSelectedAddress('0x1'); expect(tokensController.state.tokens).toStrictEqual([ { @@ -1204,14 +1435,30 @@ describe('TokensController', () => { changeNetwork(SEPOLIA); - await tokensController.addToken('0x01', 'A', 4); - await tokensController.addToken('0x02', 'B', 5); + await tokensController.addToken({ + address: '0x01', + symbol: 'A', + decimals: 4, + }); + await tokensController.addToken({ + address: '0x02', + symbol: 'B', + decimals: 5, + }); const initialTokensFirst = tokensController.state.tokens; changeNetwork(GOERLI); - await tokensController.addToken('0x03', 'C', 4); - await tokensController.addToken('0x04', 'D', 5); + await tokensController.addToken({ + address: '0x03', + symbol: 'C', + decimals: 4, + }); + await tokensController.addToken({ + address: '0x04', + symbol: 'D', + decimals: 5, + }); const initialTokensSecond = tokensController.state.tokens; @@ -1332,7 +1579,11 @@ describe('TokensController', () => { describe('onTokenListStateChange', () => { it('onTokenListChange', async () => { const stub = stubCreateEthers(tokensController, false); - await tokensController.addToken('0x01', 'bar', 2); + await tokensController.addToken({ + address: '0x01', + symbol: 'bar', + decimals: 2, + }); expect(tokensController.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index 3a57534292..33c1ee9c5d 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -16,7 +16,11 @@ import { ERC20, } from '@metamask/controller-utils'; import { abiERC721 } from '@metamask/metamask-eth-abis'; -import type { NetworkState } from '@metamask/network-controller'; +import type { + NetworkClientId, + NetworkController, + NetworkState, +} from '@metamask/network-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { AbortController as WhatwgAbortController } from 'abort-controller'; @@ -121,8 +125,6 @@ export class TokensController extends BaseController< > { private readonly mutex = new Mutex(); - private ethersProvider: any; - private abortController: WhatwgAbortController; private readonly messagingSystem: TokensControllerMessenger; @@ -166,6 +168,8 @@ export class TokensController extends BaseController< private readonly getERC20TokenName: AssetsContractController['getERC20TokenName']; + private readonly getNetworkClientById: NetworkController['getNetworkClientById']; + /** * Creates a TokensController instance. * @@ -175,6 +179,7 @@ export class TokensController extends BaseController< * @param options.onNetworkStateChange - Allows subscribing to network controller state changes. * @param options.onTokenListStateChange - Allows subscribing to token list controller state changes. * @param options.getERC20TokenName - Gets the ERC-20 token name. + * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. * @param options.config - Initial options used to configure this controller. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. @@ -185,6 +190,7 @@ export class TokensController extends BaseController< onNetworkStateChange, onTokenListStateChange, getERC20TokenName, + getNetworkClientById, config, state, messenger, @@ -200,6 +206,7 @@ export class TokensController extends BaseController< listener: (tokenListState: TokenListState) => void, ) => void; getERC20TokenName: AssetsContractController['getERC20TokenName']; + getNetworkClientById: NetworkController['getNetworkClientById']; config?: Partial; state?: Partial; messenger: TokensControllerMessenger; @@ -226,6 +233,7 @@ export class TokensController extends BaseController< this.initialize(); this.abortController = new WhatwgAbortController(); this.getERC20TokenName = getERC20TokenName; + this.getNetworkClientById = getNetworkClientById; this.messagingSystem = messenger; @@ -247,7 +255,6 @@ export class TokensController extends BaseController< this.abortController.abort(); this.abortController = new WhatwgAbortController(); this.configure({ chainId }); - this.ethersProvider = this._instantiateNewEthersProvider(); this.update({ tokens: allTokens[chainId]?.[selectedAddress] || [], ignoredTokens: allIgnoredTokens[chainId]?.[selectedAddress] || [], @@ -263,34 +270,44 @@ export class TokensController extends BaseController< }); } - _instantiateNewEthersProvider(): any { - return new Web3Provider(this.config?.provider); - } - /** * Adds a token to the stored token list. * - * @param address - Hex address of the token contract. - * @param symbol - Symbol of the token. - * @param decimals - Number of decimals the token uses. - * @param options - Object containing name and image of the token - * @param options.name - Name of the token - * @param options.image - Image of the token + * @param options - The method argument object. + * @param options.address - Hex address of the token contract. + * @param options.symbol - Symbol of the token. + * @param options.decimals - Number of decimals the token uses. + * @param options.name - Name of the token. + * @param options.image - Image of the token. * @param options.interactingAddress - The address of the account to add a token to. + * @param options.networkClientId - Network Client ID. * @returns Current token list. */ - async addToken( - address: string, - symbol: string, - decimals: number, - { - name, - image, - interactingAddress, - }: { name?: string; image?: string; interactingAddress?: string } = {}, - ): Promise { + async addToken({ + address, + symbol, + decimals, + name, + image, + interactingAddress, + networkClientId, + }: { + address: string; + symbol: string; + decimals: number; + name?: string; + image?: string; + interactingAddress?: string; + networkClientId?: NetworkClientId; + }): Promise { const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - const { chainId: currentChainId, selectedAddress } = this.config; + const { chainId, selectedAddress } = this.config; + let currentChainId = chainId; + if (networkClientId) { + currentChainId = + this.getNetworkClientById(networkClientId).configuration.chainId; + } + const accountAddress = interactingAddress || selectedAddress; const isInteractingWithWalletAccount = accountAddress === selectedAddress; const releaseLock = await this.mutex.acquire(); @@ -304,10 +321,10 @@ export class TokensController extends BaseController< allDetectedTokens[currentChainId]?.[accountAddress] || []; const newTokens: Token[] = [...tokens]; const [isERC721, tokenMetadata] = await Promise.all([ - this._detectIsERC721(address), + this._detectIsERC721(address, networkClientId), this.fetchTokenMetadata(address), ]); - if (currentChainId !== this.config.chainId) { + if (!networkClientId && currentChainId !== this.config.chainId) { throw new Error( 'TokensController Error: Switched networks while adding token', ); @@ -319,7 +336,7 @@ export class TokensController extends BaseController< image: image || formatIconUrlWithProxy({ - chainId: this.config.chainId, + chainId: currentChainId, tokenAddress: address, }), isERC721, @@ -349,6 +366,7 @@ export class TokensController extends BaseController< newIgnoredTokens, newDetectedTokens, interactingAddress: accountAddress, + interactingChainId: currentChainId, }); let newState: Partial = { @@ -378,8 +396,9 @@ export class TokensController extends BaseController< * Add a batch of tokens. * * @param tokensToImport - Array of tokens to import. + * @param networkClientId - Optional network client ID used to determine interacting chain ID. */ - async addTokens(tokensToImport: Token[]) { + async addTokens(tokensToImport: Token[], networkClientId?: NetworkClientId) { const releaseLock = await this.mutex.acquire(); const { tokens, detectedTokens, ignoredTokens } = this.state; const importedTokensMap: { [key: string]: true } = {}; @@ -414,11 +433,18 @@ export class TokensController extends BaseController< (tokenAddress) => !newTokensMap[tokenAddress.toLowerCase()], ); + let interactingChainId; + if (networkClientId) { + interactingChainId = + this.getNetworkClientById(networkClientId).configuration.chainId; + } + const { newAllTokens, newAllDetectedTokens, newAllIgnoredTokens } = this._getNewAllTokensState({ newTokens, newDetectedTokens, newIgnoredTokens, + interactingChainId, }); this.update({ @@ -619,10 +645,14 @@ export class TokensController extends BaseController< * Detects whether or not a token is ERC-721 compatible. * * @param tokenAddress - The token contract address. + * @param networkClientId - Optional network client ID to fetch contract info with. * @returns A boolean indicating whether the token address passed in supports the EIP-721 * interface. */ - async _detectIsERC721(tokenAddress: string) { + async _detectIsERC721( + tokenAddress: string, + networkClientId?: NetworkClientId, + ) { const checksumAddress = toChecksumHexAddress(tokenAddress); // if this token is already in our contract metadata map we don't need // to check against the contract @@ -635,7 +665,7 @@ export class TokensController extends BaseController< const tokenContract = this._createEthersContract( tokenAddress, abiERC721, - this.ethersProvider, + networkClientId, ); try { return await tokenContract.supportsInterface(ERC721_INTERFACE_ID); @@ -651,9 +681,14 @@ export class TokensController extends BaseController< _createEthersContract( tokenAddress: string, abi: string, - ethersProvider: any, + networkClientId?: NetworkClientId, ): Contract { - const tokenContract = new Contract(tokenAddress, abi, ethersProvider); + const provider = networkClientId + ? this.getNetworkClientById(networkClientId).provider + : this.config?.provider; + + const web3provider = new Web3Provider(provider); + const tokenContract = new Contract(tokenAddress, abi, web3provider); return tokenContract; } @@ -665,16 +700,24 @@ export class TokensController extends BaseController< * Adds a new suggestedAsset to the list of watched assets. * Parameters will be validated according to the asset type being watched. * - * @param asset - The asset to be watched. For now only ERC20 tokens are accepted. - * @param type - The asset type. - * @param interactingAddress - The address of the account that is requesting to watch the asset. + * @param options - The method options. + * @param options.asset - The asset to be watched. For now only ERC20 tokens are accepted. + * @param options.type - The asset type. + * @param options.interactingAddress - The address of the account that is requesting to watch the asset. + * @param options.networkClientId - Network Client ID. * @returns Object containing a Promise resolving to the suggestedAsset address if accepted. */ - async watchAsset( - asset: Token, - type: string, - interactingAddress?: string, - ): Promise { + async watchAsset({ + asset, + type, + interactingAddress, + networkClientId, + }: { + asset: Token; + type: string; + interactingAddress?: string; + networkClientId?: NetworkClientId; + }): Promise { if (type !== ERC20) { throw new Error(`Asset of type ${type} not supported`); } @@ -695,15 +738,20 @@ export class TokensController extends BaseController< let name; try { - name = await this.getERC20TokenName(asset.address); + name = await this.getERC20TokenName(asset.address, networkClientId); } catch (error) { name = undefined; } - await this.addToken(asset.address, asset.symbol, asset.decimals, { + const { address, symbol, decimals, image } = asset; + await this.addToken({ + address, + symbol, + decimals, name, - image: asset.image, + image, interactingAddress: suggestedAssetMeta.interactingAddress, + networkClientId, }); }