From efbec8402a28b9906450c4ed893f7a7cdc0e7d2a Mon Sep 17 00:00:00 2001 From: Prithpal Sooriya Date: Tue, 17 Feb 2026 17:31:16 +0000 Subject: [PATCH 1/3] feat: add mock asset controller messenger for testing This commit introduces a new mock implementation of the AssetControllerMessenger, which facilitates testing by simulating various actions and events related to asset management. The mock includes methods for creating a mock provider, registering action handlers, and simulating network states. Additionally, it updates the RpcDataSource and StakedBalanceDataSource tests to utilize the new mock, improving test reliability and coverage. Refactors existing tests to remove unnecessary complexity and improve clarity, ensuring that the mock accurately reflects the expected behavior of the real messenger in a controlled environment. --- .../MockAssetControllerMessenger.ts | 253 ++++++++++ .../src/data-sources/RpcDataSource.test.ts | 232 ++------- .../src/data-sources/RpcDataSource.ts | 14 +- .../StakedBalanceDataSource.test.ts | 462 +++++++----------- .../data-sources/StakedBalanceDataSource.ts | 43 +- .../services/StakedBalanceFetcher.ts | 2 +- 6 files changed, 500 insertions(+), 506 deletions(-) create mode 100644 packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts diff --git a/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts b/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts new file mode 100644 index 00000000000..afc0778ff73 --- /dev/null +++ b/packages/assets-controller/src/__fixtures__/MockAssetControllerMessenger.ts @@ -0,0 +1,253 @@ +import { defaultAbiCoder } from '@ethersproject/abi'; +import * as ProviderModule from '@ethersproject/providers'; +import { + MOCK_ANY_NAMESPACE, + Messenger, + MessengerActions, + MessengerEvents, + MockAnyNamespace, +} from '@metamask/messenger'; +import { NetworkStatus } from '@metamask/network-controller'; + +import { + NetworkState, + RpcEndpoint, + RpcEndpointType, +} from '../../../network-controller/src/NetworkController'; +import { + AssetsControllerMessenger, + getDefaultAssetsControllerState, +} from '../AssetsController'; +import { STAKING_INTERFACE } from '../data-sources/evm-rpc-services/services/StakedBalanceFetcher'; + +// Test escape hatch for mocking areas that do not need explicit types +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type TestMockType = any; + +export type MockRootMessenger = Messenger< + MockAnyNamespace, + MessengerActions, + MessengerEvents +>; + +const MAINNET_CHAIN_ID_HEX = '0x1'; +const MOCK_CHAIN_ID_CAIP = 'eip155:1'; + +export function createMockAssetControllerMessenger(): { + rootMessenger: MockRootMessenger; + assetsControllerMessenger: AssetsControllerMessenger; +} { + const rootMessenger: MockRootMessenger = new Messenger({ + namespace: MOCK_ANY_NAMESPACE, + }); + + const assetsControllerMessenger: AssetsControllerMessenger = new Messenger({ + namespace: 'AssetsController', + parent: rootMessenger, + }); + + rootMessenger.delegate({ + messenger: assetsControllerMessenger, + actions: [ + // AssetsController + 'AccountTreeController:getAccountsFromSelectedAccountGroup', + // RpcDataSource + 'TokenListController:getState', + 'NetworkController:getState', + 'NetworkController:getNetworkClientById', + // RpcDataSource, StakedBalanceDataSource + 'NetworkEnablementController:getState', + // SnapDataSource + 'SnapController:getRunnableSnaps', + 'SnapController:handleRequest', + 'PermissionController:getPermissions', + // BackendWebsocketDataSource + 'BackendWebSocketService:connect', + 'BackendWebSocketService:disconnect', + 'BackendWebSocketService:forceReconnection', + 'BackendWebSocketService:sendMessage', + 'BackendWebSocketService:sendRequest', + 'BackendWebSocketService:getConnectionInfo', + 'BackendWebSocketService:getSubscriptionsByChannel', + 'BackendWebSocketService:channelHasSubscription', + 'BackendWebSocketService:findSubscriptionsByChannelPrefix', + 'BackendWebSocketService:addChannelCallback', + 'BackendWebSocketService:removeChannelCallback', + 'BackendWebSocketService:getChannelCallbacks', + 'BackendWebSocketService:subscribe', + ], + events: [ + // AssetsController + 'AccountTreeController:selectedAccountGroupChange', + 'KeyringController:lock', + 'KeyringController:unlock', + 'PreferencesController:stateChange', + // RpcDataSource, StakedBalanceDataSource + 'NetworkController:stateChange', + 'TransactionController:transactionConfirmed', + 'TransactionController:incomingTransactionsReceived', + // StakedBalanceDataSource + 'NetworkEnablementController:stateChange', + // SnapDataSource + 'AccountsController:accountBalancesUpdated', + 'PermissionController:stateChange', + // BackendWebsocketDataSource + 'BackendWebSocketService:connectionStateChanged', + ], + }); + + return { + rootMessenger, + assetsControllerMessenger, + }; +} + +export function registerStakedMessengerActions( + rootMessenger: MockRootMessenger, + opts = { + enabledNetworkMap: { eip155: { [MAINNET_CHAIN_ID_HEX]: true } } as Record< + string, + Record + >, + mockProvider: createMockWeb3Provider({ + sharesWei: '1000000000000000000', + assetsWei: '1500000000000000000', + }), + }, +): void { + rootMessenger.registerActionHandler( + 'NetworkEnablementController:getState', + () => ({ + enabledNetworkMap: opts.enabledNetworkMap, + nativeAssetIdentifiers: {}, + }), + ); + + rootMessenger.registerActionHandler( + 'NetworkController:getNetworkClientById', + () => + ({ + provider: opts.mockProvider, + configuration: { chainId: MAINNET_CHAIN_ID_HEX }, + }) as TestMockType, + ); + + rootMessenger.registerActionHandler('NetworkController:getState', () => ({ + networkConfigurationsByChainId: { + [MAINNET_CHAIN_ID_HEX]: { + chainId: MAINNET_CHAIN_ID_HEX, + rpcEndpoints: [{ networkClientId: 'mainnet' }] as RpcEndpoint[], + defaultRpcEndpointIndex: 0, + blockExplorerUrls: [], + name: 'Mainnet', + nativeCurrency: 'ETH', + }, + }, + networksMetadata: {}, + selectedNetworkClientId: 'mainnet', + })); +} + +export function registerRpcDataSourceActions( + rootMessenger: MockRootMessenger, + opts?: { + networkState?: NetworkState; + }, +): void { + rootMessenger.registerActionHandler( + 'NetworkController:getState', + () => opts?.networkState ?? createMockNetworkState(), + ); + + rootMessenger.registerActionHandler( + 'NetworkController:getNetworkClientById', + () => + ({ + provider: { request: jest.fn().mockResolvedValue('0x0') }, + configuration: { chainId: MAINNET_CHAIN_ID_HEX }, + }) as TestMockType, + ); + + rootMessenger.registerActionHandler('AssetsController:getState', () => + getDefaultAssetsControllerState(), + ); + + rootMessenger.registerActionHandler('TokenListController:getState', () => ({ + tokensChainsCache: {}, + })); + + rootMessenger.registerActionHandler( + 'NetworkEnablementController:getState', + () => ({ + enabledNetworkMap: {}, + nativeAssetIdentifiers: { + [MOCK_CHAIN_ID_CAIP]: `${MOCK_CHAIN_ID_CAIP}/slip44:60`, + }, + }), + ); +} + +export function createMockWeb3Provider( + options = { + sharesWei: '1000000000000000000', + assetsWei: '1500000000000000000', + }, +): jest.SpyInstance { + const mockProvider = jest.spyOn(ProviderModule, 'Web3Provider'); + + const mockCalls = jest.fn().mockImplementation((callData) => { + // Will decode and return mock shares or throw + try { + STAKING_INTERFACE.decodeFunctionData('getShares', callData.data); + return defaultAbiCoder.encode(['uint256'], [options.sharesWei]); + } catch { + // do nothing + } + + // Will decode and return mock assets or throw + try { + STAKING_INTERFACE.decodeFunctionData('convertToAssets', callData.data); + return defaultAbiCoder.encode(['uint256'], [options.assetsWei]); + } catch { + // do nothing + } + + throw new Error('MOCK FAILURE: Invalid function data'); + }); + + mockProvider.mockReturnValue({ + call: mockCalls, + } as unknown as ProviderModule.Web3Provider); + + return mockProvider; +} + +export function createMockNetworkState( + chainStatus: NetworkStatus = NetworkStatus.Available, +): NetworkState { + return { + selectedNetworkClientId: 'mainnet', + networkConfigurationsByChainId: { + [MAINNET_CHAIN_ID_HEX]: { + chainId: MAINNET_CHAIN_ID_HEX, + name: 'Mainnet', + nativeCurrency: 'ETH', + defaultRpcEndpointIndex: 0, + rpcEndpoints: [ + { + networkClientId: 'mainnet', + url: 'https://mainnet.infura.io', + type: RpcEndpointType.Custom, + }, + ], + blockExplorerUrls: [], + }, + }, + networksMetadata: { + mainnet: { + status: chainStatus, + EIPS: {}, + }, + }, + } as unknown as NetworkState; +} diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts index 167f7008ee8..369725b4c07 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.test.ts @@ -1,32 +1,18 @@ /* eslint-disable jest/unbound-method */ import type { InternalAccount } from '@metamask/keyring-internal-api'; -import { Messenger, MOCK_ANY_NAMESPACE } from '@metamask/messenger'; -import type { ActionHandler, MockAnyNamespace } from '@metamask/messenger'; -import type { - AutoManagedNetworkClient, - CustomNetworkClientConfiguration, - NetworkState, -} from '@metamask/network-controller'; +import type { NetworkState } from '@metamask/network-controller'; +import { NetworkStatus, RpcEndpointType } from '@metamask/network-controller'; + +import type { RpcDataSourceOptions } from './RpcDataSource'; +import { RpcDataSource } from './RpcDataSource'; import { - NetworkClientType, - NetworkStatus, - RpcEndpointType, -} from '@metamask/network-controller'; - -import type { - RpcDataSourceOptions, - RpcDataSourceAllowedActions, - RpcDataSourceAllowedEvents, -} from './RpcDataSource'; -import { RpcDataSource, createRpcDataSource } from './RpcDataSource'; + createMockAssetControllerMessenger, + MockRootMessenger, + registerRpcDataSourceActions, +} from '../__fixtures__/MockAssetControllerMessenger'; import type { AssetsControllerMessenger } from '../AssetsController'; -import { getDefaultAssetsControllerState } from '../AssetsController'; import type { ChainId, DataRequest, Context } from '../types'; -type AllActions = RpcDataSourceAllowedActions; -type AllEvents = RpcDataSourceAllowedEvents; -type RootMessenger = Messenger; - const MOCK_CHAIN_ID_HEX = '0x1'; const MOCK_CHAIN_ID_CAIP = 'eip155:1' as ChainId; const MOCK_ACCOUNT_ID = 'mock-account-id'; @@ -35,12 +21,6 @@ type EthereumProvider = { request: jest.Mock; }; -function createMockProvider(): EthereumProvider { - return { - request: jest.fn().mockResolvedValue('0x0'), - }; -} - function createMockInternalAccount( overrides?: Partial, ): InternalAccount { @@ -128,7 +108,8 @@ type WithControllerCallback = ({ onActiveChainsUpdated, }: { controller: RpcDataSource; - messenger: RootMessenger; + rootMessenger: MockRootMessenger; + messenger: AssetsControllerMessenger; onActiveChainsUpdated: ( dataSourceName: string, chains: ChainId[], @@ -149,105 +130,27 @@ async function withController( | [WithControllerCallback] ): Promise { const [controllerOptions, fn] = args.length === 2 ? args : [{}, args[0]]; - const { - options = {}, - networkState = createMockNetworkState(), - actionHandlerOverrides = {}, - } = controllerOptions; - - const messenger: RootMessenger = new Messenger({ - namespace: MOCK_ANY_NAMESPACE, - }); + const { options = {}, networkState = createMockNetworkState() } = + controllerOptions; - const rpcDataSourceMessenger = new Messenger< - 'RpcDataSource', - AllActions, - AllEvents, - RootMessenger - >({ - namespace: 'RpcDataSource', - parent: messenger, - }); + const { rootMessenger, assetsControllerMessenger } = + createMockAssetControllerMessenger(); + registerRpcDataSourceActions(rootMessenger, { networkState }); - messenger.delegate({ - messenger: rpcDataSourceMessenger, - actions: [ - 'NetworkController:getState', - 'NetworkController:getNetworkClientById', - 'AssetsController:getState', - 'TokenListController:getState', - 'NetworkEnablementController:getState', - ], - events: ['NetworkController:stateChange'], - }); - - // Mock NetworkController:getState - messenger.registerActionHandler( - 'NetworkController:getState', - actionHandlerOverrides['NetworkController:getState'] ?? - ((): NetworkState => networkState), - ); - - // Mock NetworkController:getNetworkClientById (minimal shape; full type not needed for tests) - const getNetworkClientByIdHandler = - actionHandlerOverrides['NetworkController:getNetworkClientById'] ?? - ((): AutoManagedNetworkClient => - ({ - provider: createMockProvider(), - configuration: { - chainId: MOCK_CHAIN_ID_HEX, - ticker: 'ETH', - rpcUrl: 'https://mainnet.infura.io', - type: NetworkClientType.Custom, - }, - }) as unknown as AutoManagedNetworkClient); - messenger.registerActionHandler( - 'NetworkController:getNetworkClientById', - getNetworkClientByIdHandler as ActionHandler< - RpcDataSourceAllowedActions, - 'NetworkController:getNetworkClientById' - >, - ); - - // Mock AssetsController:getState - messenger.registerActionHandler('AssetsController:getState', () => - getDefaultAssetsControllerState(), - ); - - // Mock TokenListController:getState - messenger.registerActionHandler('TokenListController:getState', () => ({ - tokensChainsCache: {}, - })); - - // Mock NetworkEnablementController:getState - messenger.registerActionHandler( - 'NetworkEnablementController:getState', - () => ({ - enabledNetworkMap: {}, - nativeAssetIdentifiers: { - [MOCK_CHAIN_ID_CAIP]: `${MOCK_CHAIN_ID_CAIP}/slip44:60`, - }, - }), - ); - - const onActiveChainsUpdated = - ( - options as { - onActiveChainsUpdated?: ( - dataSourceName: string, - chains: ChainId[], - previousChains: ChainId[], - ) => void; - } - ).onActiveChainsUpdated ?? jest.fn(); + const onActiveChainsUpdated = options.onActiveChainsUpdated ?? jest.fn(); const controller = new RpcDataSource({ - messenger: rpcDataSourceMessenger as unknown as AssetsControllerMessenger, + messenger: assetsControllerMessenger, onActiveChainsUpdated, ...options, }); try { - return await fn({ controller, messenger, onActiveChainsUpdated }); + return await fn({ + controller, + messenger: assetsControllerMessenger, + rootMessenger, + onActiveChainsUpdated, + }); } finally { controller.destroy(); } @@ -345,24 +248,28 @@ describe('RpcDataSource', () => { }, }, }, - async ({ controller, messenger }) => { + async ({ controller, rootMessenger }) => { source = controller; // Trigger callback via network state change (first call is during construction, before source is set). const newNetworkState = createMockNetworkState( NetworkStatus.Available, ); - (messenger.publish as CallableFunction)( + rootMessenger.publish( 'NetworkController:stateChange', newNetworkState, [], ); await new Promise(process.nextTick); expect(callbackResult).not.toBeNull(); - const result = callbackResult as { - syncChains: ChainId[]; - newChains: ChainId[]; + const assertNotNull: ( + value: Val | null, + ) => asserts value is Val = (value) => { + expect(value).not.toBeNull(); }; - expect(result.syncChains).toStrictEqual(result.newChains); + assertNotNull(callbackResult); + expect(callbackResult.syncChains).toStrictEqual( + callbackResult.newChains, + ); const chains = await controller.getActiveChains(); expect(chains).toContain(MOCK_CHAIN_ID_CAIP); }, @@ -391,7 +298,7 @@ describe('RpcDataSource', () => { selectedNetworkClientId: 'mainnet', networkConfigurationsByChainId: {}, networksMetadata: {}, - } as unknown as NetworkState; + }; await withController( { networkState: emptyNetworkState }, @@ -424,7 +331,7 @@ describe('RpcDataSource', () => { it('returns undefined for non-existent chain', async () => { await withController(({ controller }) => { - const status = controller.getChainStatus('eip155:999' as ChainId); + const status = controller.getChainStatus('eip155:999'); expect(status).toBeUndefined(); }); }); @@ -454,10 +361,10 @@ describe('RpcDataSource', () => { accountsWithSupportedChains: [ { account, - supportedChains: ['eip155:999' as ChainId], + supportedChains: ['eip155:999'], }, ], - chainIds: ['eip155:999' as ChainId], + chainIds: ['eip155:999'], dataTypes: ['balance'], }; @@ -567,7 +474,7 @@ describe('RpcDataSource', () => { const middleware = controller.assetsMiddleware; const context: Context = { request: createDataRequest({ - chainIds: ['eip155:999' as ChainId], + chainIds: ['eip155:999'], }), response: {}, getAssetsState: jest.fn(), @@ -613,7 +520,7 @@ describe('RpcDataSource', () => { describe('network state changes', () => { it('updates chains when network state changes', async () => { - await withController(async ({ controller, messenger }) => { + await withController(async ({ controller, rootMessenger }) => { const newNetworkState = createMockNetworkState(NetworkStatus.Available); newNetworkState.networkConfigurationsByChainId['0x89'] = { chainId: '0x89', @@ -634,7 +541,7 @@ describe('RpcDataSource', () => { EIPS: {}, }; - (messenger.publish as CallableFunction)( + rootMessenger.publish( 'NetworkController:stateChange', newNetworkState, [], @@ -648,63 +555,6 @@ describe('RpcDataSource', () => { }); }); - describe('createRpcDataSource', () => { - it('creates an RpcDataSource instance', async () => { - const messenger: RootMessenger = new Messenger({ - namespace: MOCK_ANY_NAMESPACE, - }); - - const rpcDataSourceMessenger = new Messenger< - 'RpcDataSource', - AllActions, - AllEvents, - RootMessenger - >({ - namespace: 'RpcDataSource', - parent: messenger, - }); - - messenger.delegate({ - messenger: rpcDataSourceMessenger, - actions: [ - 'NetworkController:getState', - 'NetworkController:getNetworkClientById', - ], - events: ['NetworkController:stateChange'], - }); - - messenger.registerActionHandler('NetworkController:getState', () => - createMockNetworkState(), - ); - messenger.registerActionHandler( - 'NetworkController:getNetworkClientById', - (): AutoManagedNetworkClient => - ({ - provider: createMockProvider(), - configuration: { - chainId: MOCK_CHAIN_ID_HEX, - ticker: 'ETH', - rpcUrl: 'https://mainnet.infura.io', - type: NetworkClientType.Custom, - }, - }) as unknown as AutoManagedNetworkClient, - ); - - const controller = createRpcDataSource({ - messenger: - rpcDataSourceMessenger as unknown as AssetsControllerMessenger, - onActiveChainsUpdated: jest.fn(), - }); - - try { - expect(controller).toBeInstanceOf(RpcDataSource); - expect(controller.getName()).toBe('RpcDataSource'); - } finally { - controller.destroy(); - } - }); - }); - describe('instance methods', () => { it('exposes getAssetsMiddleware on instance', async () => { await withController(({ controller }) => { diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.ts b/packages/assets-controller/src/data-sources/RpcDataSource.ts index 06fd1220963..922f0bb40d8 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.ts @@ -1233,8 +1233,7 @@ export class RpcDataSource extends AbstractDataSource< 'NetworkEnablementController:getState', ); - return (nativeAssetIdentifiers[chainId] ?? - `${chainId}/slip44:60`) as Caip19AssetId; + return nativeAssetIdentifiers[chainId] ?? `${chainId}/slip44:60`; } /** @@ -1245,9 +1244,7 @@ export class RpcDataSource extends AbstractDataSource< */ #getExistingAssetsMetadata(): Record { try { - const state = this.#messenger.call('AssetsController:getState') as { - assetsInfo?: Record; - }; + const state = this.#messenger.call('AssetsController:getState'); return state.assetsInfo ?? {}; } catch { // If AssetsController:getState fails, return empty metadata @@ -1288,12 +1285,7 @@ export class RpcDataSource extends AbstractDataSource< const lowerAddress = tokenAddress.toLowerCase(); for (const [address, tokenData] of Object.entries(chainTokenList)) { if (address.toLowerCase() === lowerAddress) { - const token = tokenData as { - symbol?: string; - name?: string; - decimals?: number; - iconUrl?: string; - }; + const token = tokenData; if (token.symbol && token.decimals !== undefined) { return { type: 'erc20', diff --git a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts index 798eed5f460..457af9ac656 100644 --- a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts @@ -1,8 +1,14 @@ -import { defaultAbiCoder } from '@ethersproject/abi'; import type { InternalAccount } from '@metamask/keyring-internal-api'; +import { TransactionStatus } from '@metamask/transaction-controller'; import type { StakedBalanceDataSourceOptions } from './StakedBalanceDataSource'; import { StakedBalanceDataSource } from './StakedBalanceDataSource'; +import { + MockRootMessenger, + createMockAssetControllerMessenger, + createMockWeb3Provider, + registerStakedMessengerActions, +} from '../__fixtures__/MockAssetControllerMessenger'; import type { AssetsControllerMessenger } from '../AssetsController'; import type { AssetsControllerStateInternal, @@ -11,42 +17,6 @@ import type { DataRequest, } from '../types'; -function createMockProvider(options: { - sharesWei?: string; - assetsWei?: string; -}): { call: jest.Mock } { - const { sharesWei = '0', assetsWei = '0' } = options; - let callCount = 0; - return { - call: jest.fn().mockImplementation(async () => { - callCount += 1; - if (callCount === 1) { - return defaultAbiCoder.encode(['uint256'], [sharesWei]); - } - return defaultAbiCoder.encode(['uint256'], [assetsWei]); - }), - }; -} - -jest.mock('@ethersproject/providers', () => { - const actual = jest.requireActual('@ethersproject/providers'); - return { - ...actual, - Web3Provider: jest.fn().mockImplementation( - (provider: { - call?: jest.Mock; - }): { - call: (params: unknown) => Promise; - } => ({ - call: (params: unknown) => - provider?.call - ? Promise.resolve(provider.call(params)) - : Promise.resolve('0x0'), - }), - ), - }; -}); - const MAINNET_CHAIN_ID_HEX = '0x1'; const MAINNET_CHAIN_ID_CAIP = 'eip155:1' as ChainId; const STAKING_CONTRACT_MAINNET = '0x4FEF9D741011476750A243aC70b9789a63dd47Df'; @@ -109,21 +79,10 @@ function createMiddlewareContext(overrides?: Partial): Context { }; } -type MockMessenger = { - subscribe: jest.Mock; - call: jest.Mock; - publish: (event: string, ...args: unknown[]) => void; - getSubscribeHandlers: () => Map void>; -}; - -type NetworkEnablementState = { - enabledNetworkMap: Record>; -}; - type WithControllerOptions = { options?: Partial; enabledNetworkMap?: Record>; - mockProvider?: ReturnType; + mockProvider?: ReturnType; }; type WithControllerCallback = ({ @@ -133,72 +92,19 @@ type WithControllerCallback = ({ mockProvider, }: { controller: StakedBalanceDataSource; - messenger: MockMessenger; + messenger: AssetsControllerMessenger; + mockMessengerCall: jest.SpyInstance; + mockMessengerSubscribe: jest.SpyInstance; + mockMessengerUnsubscribe: jest.SpyInstance; + rootMessenger: MockRootMessenger; onActiveChainsUpdated: ( dataSourceName: string, chains: ChainId[], previousChains: ChainId[], ) => void; - mockProvider: ReturnType; + mockProvider: ReturnType; }) => Promise | ReturnValue; -function createMockMessenger( - mockProvider?: ReturnType, -): MockMessenger { - const subscribeHandlers: Map void> = new Map(); - const provider = mockProvider ?? createMockProvider({}); - - const messenger = { - subscribe: jest - .fn() - .mockImplementation((event: string, handler: (p: unknown) => void) => { - subscribeHandlers.set(event, handler); - return jest.fn(() => subscribeHandlers.delete(event)); - }), - call: jest.fn().mockImplementation((action: string, id?: string) => { - if (action === 'NetworkEnablementController:getState') { - return { - enabledNetworkMap: { - eip155: { [MAINNET_CHAIN_ID_HEX]: true }, - }, - } as NetworkEnablementState; - } - if (action === 'NetworkController:getState') { - return { - networkConfigurationsByChainId: { - [MAINNET_CHAIN_ID_HEX]: { - chainId: MAINNET_CHAIN_ID_HEX, - rpcEndpoints: [{ networkClientId: 'mainnet' }], - defaultRpcEndpointIndex: 0, - }, - }, - networksMetadata: {}, - }; - } - if ( - action === 'NetworkController:getNetworkClientById' && - id === 'mainnet' - ) { - return { - provider, - configuration: { chainId: MAINNET_CHAIN_ID_HEX }, - }; - } - return undefined; - }), - publish: (event: string, ...args: unknown[]): void => { - const handler = subscribeHandlers.get(event); - if (handler) { - handler(args[0]); - } - }, - getSubscribeHandlers: (): Map void> => - subscribeHandlers, - }; - - return messenger; -} - async function withController( ...args: | [WithControllerOptions, WithControllerCallback] @@ -208,41 +114,34 @@ async function withController( const { options = {}, enabledNetworkMap = { eip155: { [MAINNET_CHAIN_ID_HEX]: true } }, - mockProvider = createMockProvider({ + mockProvider = createMockWeb3Provider({ sharesWei: '1000000000000000000', assetsWei: '1500000000000000000', }), } = controllerOptions; - const messenger = createMockMessenger(mockProvider); - messenger.call.mockImplementation((action: string, id?: string) => { - if (action === 'NetworkEnablementController:getState') { - return { enabledNetworkMap }; - } - if (action === 'NetworkController:getState') { - return { - networkConfigurationsByChainId: { - [MAINNET_CHAIN_ID_HEX]: { - chainId: MAINNET_CHAIN_ID_HEX, - rpcEndpoints: [{ networkClientId: 'mainnet' }], - defaultRpcEndpointIndex: 0, - }, - }, - networksMetadata: {}, - }; - } - if ( - action === 'NetworkController:getNetworkClientById' && - id === 'mainnet' - ) { - return { - provider: mockProvider, - configuration: { chainId: MAINNET_CHAIN_ID_HEX }, - }; - } - return undefined; + const { assetsControllerMessenger, rootMessenger } = + createMockAssetControllerMessenger(); + registerStakedMessengerActions(rootMessenger, { + enabledNetworkMap, + mockProvider, }); + // spy on staked messenger calls, so we can inspect and assert + const mockStakedMessengerCall = jest.spyOn(assetsControllerMessenger, 'call'); + + // spy on staked messenger subscriptions, so we can inspect and assert + const mockStakedMessengerSubscribe = jest.spyOn( + assetsControllerMessenger, + 'subscribe', + ); + + // spy on staked messenger unsubscribe, so we can inspect and assert + const mockStakedMessengerUnsubscribe = jest.spyOn( + assetsControllerMessenger, + 'clearEventSubscriptions', + ); + const onActiveChainsUpdated = ( options as { @@ -250,20 +149,23 @@ async function withController( } ).onActiveChainsUpdated ?? jest.fn(); - const messengerForController = - messenger as unknown as AssetsControllerMessenger; const controller = new StakedBalanceDataSource({ - messenger: messengerForController, + messenger: assetsControllerMessenger, onActiveChainsUpdated, ...options, + pollInterval: 1000, }); try { return await fn({ controller, - messenger, + messenger: assetsControllerMessenger, + mockMessengerCall: mockStakedMessengerCall, + mockMessengerSubscribe: mockStakedMessengerSubscribe, + mockMessengerUnsubscribe: mockStakedMessengerUnsubscribe, onActiveChainsUpdated, mockProvider, + rootMessenger, }); } finally { controller.destroy(); @@ -313,20 +215,20 @@ describe('StakedBalanceDataSource', () => { }); it('subscribes to transaction and network events', async () => { - await withController(({ messenger }) => { - expect(messenger.subscribe).toHaveBeenCalledWith( + await withController(({ mockMessengerSubscribe }) => { + expect(mockMessengerSubscribe).toHaveBeenCalledWith( 'TransactionController:transactionConfirmed', expect.any(Function), ); - expect(messenger.subscribe).toHaveBeenCalledWith( + expect(mockMessengerSubscribe).toHaveBeenCalledWith( 'TransactionController:incomingTransactionsReceived', expect.any(Function), ); - expect(messenger.subscribe).toHaveBeenCalledWith( + expect(mockMessengerSubscribe).toHaveBeenCalledWith( 'NetworkController:stateChange', expect.any(Function), ); - expect(messenger.subscribe).toHaveBeenCalledWith( + expect(mockMessengerSubscribe).toHaveBeenCalledWith( 'NetworkEnablementController:stateChange', expect.any(Function), ); @@ -401,32 +303,35 @@ describe('StakedBalanceDataSource', () => { }); it('returns staked balance and metadata for mainnet when fetcher returns data', async () => { - await withController(async ({ controller, messenger }) => { - const account = createMockInternalAccount(); - const request = createDataRequest({ - accounts: [account], - chainIds: [MAINNET_CHAIN_ID_CAIP], - accountsWithSupportedChains: [ - { account, supportedChains: [MAINNET_CHAIN_ID_CAIP] }, - ], - }); - const response = await controller.fetch(request); - expect(messenger.call).toHaveBeenCalledWith( - 'NetworkController:getNetworkClientById', - 'mainnet', - ); - expect(response).toBeDefined(); - expect(messenger.call).toHaveBeenCalledWith( - 'NetworkController:getNetworkClientById', - 'mainnet', - ); - }); + await withController( + async ({ controller, mockMessengerCall: mockMessengerCalls }) => { + const account = createMockInternalAccount(); + const request = createDataRequest({ + accounts: [account], + chainIds: [MAINNET_CHAIN_ID_CAIP], + accountsWithSupportedChains: [ + { account, supportedChains: [MAINNET_CHAIN_ID_CAIP] }, + ], + }); + + const response = await controller.fetch(request); + expect(response).toBeDefined(); + + expect(mockMessengerCalls).toHaveBeenCalledWith( + 'NetworkController:getNetworkClientById', + 'mainnet', + ); + }, + ); }); it('returns zero amount when getShares returns zero', async () => { await withController( { - mockProvider: createMockProvider({ sharesWei: '0', assetsWei: '0' }), + mockProvider: createMockWeb3Provider({ + sharesWei: '0', + assetsWei: '0', + }), }, async ({ controller }) => { const account = createMockInternalAccount(); @@ -501,126 +406,147 @@ describe('StakedBalanceDataSource', () => { }); describe('transaction events', () => { + const arrange = async (props: { + controller: StakedBalanceDataSource; + }): Promise => { + // subscribe and wait ensure polling finishes before we start test + const onAssetsUpdate = jest.fn(); + await props.controller.subscribe({ + request: createDataRequest(), + subscriptionId: 'test-sub', + isUpdate: false, + onAssetsUpdate, + getAssetsState: getMockAssetsState, + }); + await new Promise((resolve) => setTimeout(resolve, 100)); + onAssetsUpdate.mockClear(); + + return onAssetsUpdate; + }; + it('refreshes staked balance when transactionConfirmed involves staking contract (to)', async () => { - await withController(async ({ controller, messenger }) => { - const onAssetsUpdate = jest.fn(); - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate, - getAssetsState: getMockAssetsState, - }); - onAssetsUpdate.mockClear(); - (messenger.publish as (e: string, p: unknown) => void)( - 'TransactionController:transactionConfirmed', - { - chainId: MAINNET_CHAIN_ID_HEX, - txParams: { to: STAKING_CONTRACT_MAINNET }, + await withController(async ({ controller, rootMessenger }) => { + // Arrange + const onAssetsUpdate = await arrange({ controller }); + + // Act + rootMessenger.publish('TransactionController:transactionConfirmed', { + id: '1', + networkClientId: 'mainnet', + status: TransactionStatus.confirmed, + time: Date.now(), + chainId: MAINNET_CHAIN_ID_HEX, + txParams: { + to: STAKING_CONTRACT_MAINNET, + from: '0x0000000000000000000000000000000000000000', }, - ); + }); + + // Assert await new Promise((resolve) => setTimeout(resolve, 300)); - expect(onAssetsUpdate.mock.calls.length).toBeGreaterThanOrEqual(0); + expect(onAssetsUpdate).toHaveBeenCalledTimes(1); }); }); it('does not refresh when transactionConfirmed does not involve staking contract', async () => { - await withController(async ({ controller, messenger }) => { - const onAssetsUpdate = jest.fn(); - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate, - getAssetsState: getMockAssetsState, - }); - onAssetsUpdate.mockClear(); - (messenger.publish as (e: string, p: unknown) => void)( - 'TransactionController:transactionConfirmed', - { - chainId: MAINNET_CHAIN_ID_HEX, - txParams: { - from: '0xabcdef1234567890abcdef1234567890abcdef12', - to: '0x1234567890123456789012345678901234567890', - }, + await withController(async ({ controller, rootMessenger }) => { + // Arrange + const onAssetsUpdate = await arrange({ controller }); + + // Act + rootMessenger.publish('TransactionController:transactionConfirmed', { + id: '1', + networkClientId: 'mainnet', + status: TransactionStatus.confirmed, + time: Date.now(), + chainId: MAINNET_CHAIN_ID_HEX, + txParams: { + from: '0xabcdef1234567890abcdef1234567890abcdef12', + to: '0x1234567890123456789012345678901234567890', }, - ); + }); + + // Assert await new Promise((resolve) => setTimeout(resolve, 50)); expect(onAssetsUpdate).not.toHaveBeenCalled(); }); }); it('refreshes when transactionConfirmed has from equal to staking contract', async () => { - await withController(async ({ controller, messenger }) => { - const onAssetsUpdate = jest.fn(); - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate, - getAssetsState: getMockAssetsState, + await withController(async ({ controller, rootMessenger }) => { + // Arrange + const onAssetsUpdate = await arrange({ controller }); + + // Act + rootMessenger.publish('TransactionController:transactionConfirmed', { + id: '1', + networkClientId: 'mainnet', + status: TransactionStatus.confirmed, + time: Date.now(), + chainId: MAINNET_CHAIN_ID_HEX, + txParams: { from: STAKING_CONTRACT_MAINNET.toLowerCase() }, }); - onAssetsUpdate.mockClear(); - (messenger.publish as (e: string, p: unknown) => void)( - 'TransactionController:transactionConfirmed', - { - chainId: MAINNET_CHAIN_ID_HEX, - txParams: { from: STAKING_CONTRACT_MAINNET.toLowerCase() }, - }, - ); + + // Assert await new Promise((resolve) => setTimeout(resolve, 300)); - expect(onAssetsUpdate.mock.calls.length).toBeGreaterThanOrEqual(0); + expect(onAssetsUpdate).toHaveBeenCalledTimes(1); }); }); it('refreshes when incomingTransactionsReceived includes tx involving staking contract', async () => { - await withController(async ({ controller, messenger }) => { - const onAssetsUpdate = jest.fn(); - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate, - getAssetsState: getMockAssetsState, - }); - onAssetsUpdate.mockClear(); - (messenger.publish as (e: string, p: unknown) => void)( + await withController(async ({ controller, rootMessenger }) => { + // Arrange + const onAssetsUpdate = await arrange({ controller }); + + // Act + rootMessenger.publish( 'TransactionController:incomingTransactionsReceived', [ { + id: '1', + networkClientId: 'mainnet', + status: TransactionStatus.confirmed, + time: Date.now(), chainId: MAINNET_CHAIN_ID_HEX, - txParams: { to: STAKING_CONTRACT_MAINNET }, + txParams: { + to: STAKING_CONTRACT_MAINNET, + from: '0x0000000000000000000000000000000000000000', + }, }, ], ); + + // Assert await new Promise((resolve) => setTimeout(resolve, 300)); - expect(onAssetsUpdate.mock.calls.length).toBeGreaterThanOrEqual(0); + expect(onAssetsUpdate).toHaveBeenCalledTimes(1); }); }); it('does not refresh when incomingTransactionsReceived has no tx involving staking contract', async () => { - await withController(async ({ controller, messenger }) => { - const onAssetsUpdate = jest.fn(); - await controller.subscribe({ - request: createDataRequest(), - subscriptionId: 'test-sub', - isUpdate: false, - onAssetsUpdate, - getAssetsState: getMockAssetsState, - }); - onAssetsUpdate.mockClear(); - (messenger.publish as (e: string, p: unknown) => void)( + await withController(async ({ controller, rootMessenger }) => { + // Arrange + const onAssetsUpdate = await arrange({ controller }); + + // Act + rootMessenger.publish( 'TransactionController:incomingTransactionsReceived', [ { + id: '1', + networkClientId: 'mainnet', + status: TransactionStatus.confirmed, + time: Date.now(), chainId: MAINNET_CHAIN_ID_HEX, txParams: { to: '0x1234567890123456789012345678901234567890', + from: '0x0000000000000000000000000000000000000000', }, }, ], ); - await new Promise((resolve) => setTimeout(resolve, 50)); + + // Assert + await new Promise((resolve) => setTimeout(resolve, 100)); expect(onAssetsUpdate).not.toHaveBeenCalled(); }); }); @@ -696,39 +622,25 @@ describe('StakedBalanceDataSource', () => { describe('destroy', () => { it('unsubscribes from transaction and network events', async () => { - const unsubscribeConfirmed = jest.fn(); - const unsubscribeIncoming = jest.fn(); - const unsubscribeNetwork = jest.fn(); - const unsubscribeEnablement = jest.fn(); - const messenger = createMockMessenger(createMockProvider({})); - messenger.subscribe.mockImplementation((event: string) => { - if (event === 'TransactionController:transactionConfirmed') { - return unsubscribeConfirmed; - } - if (event === 'TransactionController:incomingTransactionsReceived') { - return unsubscribeIncoming; - } - if (event === 'NetworkController:stateChange') { - return unsubscribeNetwork; - } - if (event === 'NetworkEnablementController:stateChange') { - return unsubscribeEnablement; - } - return jest.fn(); - }); + await withController(async ({ controller, mockMessengerUnsubscribe }) => { + // Act + controller.destroy(); - const messengerForController = - messenger as unknown as AssetsControllerMessenger; - const controller = new StakedBalanceDataSource({ - messenger: messengerForController, - onActiveChainsUpdated: jest.fn(), + // Assert + expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( + 'TransactionController:transactionConfirmed', + ); + expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( + 'TransactionController:incomingTransactionsReceived', + ); + expect(mockMessengerUnsubscribe).toHaveBeenCalled(); + expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( + 'NetworkController:stateChange', + ); + expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( + 'NetworkEnablementController:stateChange', + ); }); - controller.destroy(); - - expect(unsubscribeConfirmed).toHaveBeenCalled(); - expect(unsubscribeIncoming).toHaveBeenCalled(); - expect(unsubscribeNetwork).toHaveBeenCalled(); - expect(unsubscribeEnablement).toHaveBeenCalled(); }); }); }); diff --git a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts index b8e433823ab..53d47b3253a 100644 --- a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts +++ b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts @@ -194,49 +194,29 @@ export class StakedBalanceDataSource extends AbstractDataSource< this.#handleStakedBalanceUpdate.bind(this), ); - const unsubConfirmed = this.#messenger.subscribe( + this.#messenger.subscribe( 'TransactionController:transactionConfirmed', this.#onTransactionConfirmed.bind(this), ); - this.#unsubscribeTransactionConfirmed = - typeof unsubConfirmed === 'function' ? unsubConfirmed : undefined; - const unsubIncoming = this.#messenger.subscribe( + this.#messenger.subscribe( 'TransactionController:incomingTransactionsReceived', this.#onIncomingTransactions.bind(this), ); - this.#unsubscribeIncomingTransactions = - typeof unsubIncoming === 'function' ? unsubIncoming : undefined; - const unsubNetwork = this.#messenger.subscribe( + this.#messenger.subscribe( 'NetworkController:stateChange', this.#onNetworkStateChange.bind(this), ); - this.#unsubscribeNetworkStateChange = - typeof unsubNetwork === 'function' ? unsubNetwork : undefined; - const unsubEnablement = this.#messenger.subscribe( + this.#messenger.subscribe( 'NetworkEnablementController:stateChange', this.#onNetworkEnablementControllerStateChange.bind(this), ); - this.#unsubscribeNetworkEnablementControllerStateChange = - typeof unsubEnablement === 'function' ? unsubEnablement : undefined; this.#initializeActiveChains(); } - readonly #unsubscribeTransactionConfirmed: (() => void) | undefined = - undefined; - - readonly #unsubscribeIncomingTransactions: (() => void) | undefined = - undefined; - - readonly #unsubscribeNetworkStateChange: (() => void) | undefined = undefined; - - readonly #unsubscribeNetworkEnablementControllerStateChange: - | (() => void) - | undefined = undefined; - /** * When NetworkController state changes (e.g. RPC endpoints or network clients * reconfigured), clear the provider cache so subsequent fetches use fresh @@ -919,10 +899,17 @@ export class StakedBalanceDataSource extends AbstractDataSource< * Destroy the data source and clean up all resources. */ destroy(): void { - this.#unsubscribeTransactionConfirmed?.(); - this.#unsubscribeIncomingTransactions?.(); - this.#unsubscribeNetworkStateChange?.(); - this.#unsubscribeNetworkEnablementControllerStateChange?.(); + this.#messenger.clearEventSubscriptions( + 'TransactionController:transactionConfirmed', + ); + this.#messenger.clearEventSubscriptions( + 'TransactionController:incomingTransactionsReceived', + ); + this.#messenger.clearEventSubscriptions('NetworkController:stateChange'); + this.#messenger.clearEventSubscriptions( + 'NetworkEnablementController:stateChange', + ); + for (const subscription of this.#activeSubscriptions.values()) { for (const token of subscription.pollingTokens) { this.#stakedBalanceFetcher.stopPollingByPollingToken(token); diff --git a/packages/assets-controller/src/data-sources/evm-rpc-services/services/StakedBalanceFetcher.ts b/packages/assets-controller/src/data-sources/evm-rpc-services/services/StakedBalanceFetcher.ts index f3af0f8358c..035c160b18f 100644 --- a/packages/assets-controller/src/data-sources/evm-rpc-services/services/StakedBalanceFetcher.ts +++ b/packages/assets-controller/src/data-sources/evm-rpc-services/services/StakedBalanceFetcher.ts @@ -65,7 +65,7 @@ const STAKING_CONTRACT_ABI = [ }, ]; -const STAKING_INTERFACE = new Interface(STAKING_CONTRACT_ABI); +export const STAKING_INTERFACE = new Interface(STAKING_CONTRACT_ABI); const STAKING_DECIMALS = 18; From 1e4b4ecaf9c2067788fc1ccbbcdc9e1016e03b78 Mon Sep 17 00:00:00 2001 From: Prithpal Sooriya Date: Tue, 17 Feb 2026 17:43:31 +0000 Subject: [PATCH 2/3] docs: update changelog --- packages/assets-controller/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index bd8f627ad8a..45b88bf92ff 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Refactor data source tests to use shared `MockAssetControllerMessenger` fixture ([#7958](https://github.com/MetaMask/core/pull/7958)) + - Export `STAKING_INTERFACE` from the staked balance fetcher for use with the staking contract ABI. + - `StakedBalanceDataSource` teardown now uses the messenger's `clearEventSubscriptions`; custom messenger implementations must support it for correct cleanup. + ## [2.0.0] ### Added From b2c402ca4e146f8f6ed55d2687523fcd99fd9351 Mon Sep 17 00:00:00 2001 From: Prithpal Sooriya Date: Tue, 17 Feb 2026 17:46:01 +0000 Subject: [PATCH 3/3] refactor: remove unused event unsubscription logic from StakedBalanceDataSource This commit cleans up the `destroy` method in the `StakedBalanceDataSource` class by removing the unsubscription logic for transaction and network events, which is no longer necessary. Corresponding test cases have also been removed to reflect this change. --- .../StakedBalanceDataSource.test.ts | 24 ------------------- .../data-sources/StakedBalanceDataSource.ts | 11 --------- 2 files changed, 35 deletions(-) diff --git a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts index 457af9ac656..cadf81c1ae4 100644 --- a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.test.ts @@ -619,28 +619,4 @@ describe('StakedBalanceDataSource', () => { }); }); }); - - describe('destroy', () => { - it('unsubscribes from transaction and network events', async () => { - await withController(async ({ controller, mockMessengerUnsubscribe }) => { - // Act - controller.destroy(); - - // Assert - expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( - 'TransactionController:transactionConfirmed', - ); - expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( - 'TransactionController:incomingTransactionsReceived', - ); - expect(mockMessengerUnsubscribe).toHaveBeenCalled(); - expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( - 'NetworkController:stateChange', - ); - expect(mockMessengerUnsubscribe).toHaveBeenCalledWith( - 'NetworkEnablementController:stateChange', - ); - }); - }); - }); }); diff --git a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts index 53d47b3253a..87eab17512c 100644 --- a/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts +++ b/packages/assets-controller/src/data-sources/StakedBalanceDataSource.ts @@ -899,17 +899,6 @@ export class StakedBalanceDataSource extends AbstractDataSource< * Destroy the data source and clean up all resources. */ destroy(): void { - this.#messenger.clearEventSubscriptions( - 'TransactionController:transactionConfirmed', - ); - this.#messenger.clearEventSubscriptions( - 'TransactionController:incomingTransactionsReceived', - ); - this.#messenger.clearEventSubscriptions('NetworkController:stateChange'); - this.#messenger.clearEventSubscriptions( - 'NetworkEnablementController:stateChange', - ); - for (const subscription of this.#activeSubscriptions.values()) { for (const token of subscription.pollingTokens) { this.#stakedBalanceFetcher.stopPollingByPollingToken(token);