diff --git a/packages/assets-controller/CHANGELOG.md b/packages/assets-controller/CHANGELOG.md index ba1d2af7b17..f9b80fa0599 100644 --- a/packages/assets-controller/CHANGELOG.md +++ b/packages/assets-controller/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add parallel middlewares in `ParallelMiddleware.ts`: `createParallelBalanceMiddleware` runs balance data sources (Accounts API, Snap, RPC) in parallel with chain partitioning and a fallback round for failed chains; `createParallelMiddleware` runs TokenDataSource and PriceDataSource in parallel (same request, merged response). Both use `mergeDataResponses` and limited concurrency via `p-limit` ([#7950](https://github.com/MetaMask/core/pull/7950)) +- Add `@metamask/client-controller` dependency and subscribe to `ClientController:stateChange`. Asset tracking runs only when the UI is open (ClientController) and the keyring is unlocked (KeyringController), and stops when either the UI closes or the keyring locks (Client + Keyring lifecycle) ([#7950](https://github.com/MetaMask/core/pull/7950)) +- Add full and merge update modes: `DataResponse.updateMode` and type `AssetsUpdateMode` (`'full'` | `'merge'`). Fetch uses `'full'` (response is authoritative for scope; custom assets not in response are preserved). Subscriptions could use `'merge'` or `'full'` depending on data sources. Default is `'merge'` when omitted ([#7950](https://github.com/MetaMask/core/pull/7950)) + ### Changed - Bump `@metamask/transaction-controller` from `^62.17.1` to `^62.18.0` ([#8005](https://github.com/MetaMask/core/pull/8005)) diff --git a/packages/assets-controller/package.json b/packages/assets-controller/package.json index 4b2a9343433..5b7da0b8673 100644 --- a/packages/assets-controller/package.json +++ b/packages/assets-controller/package.json @@ -55,6 +55,7 @@ "@metamask/account-tree-controller": "^4.1.1", "@metamask/assets-controllers": "^100.0.2", "@metamask/base-controller": "^9.0.0", + "@metamask/client-controller": "^1.0.0", "@metamask/controller-utils": "^11.19.0", "@metamask/core-backend": "^6.0.0", "@metamask/keyring-api": "^21.5.0", @@ -73,7 +74,8 @@ "@metamask/utils": "^11.9.0", "async-mutex": "^0.5.0", "bignumber.js": "^9.1.2", - "lodash": "^4.17.21" + "lodash": "^4.17.21", + "p-limit": "^3.1.0" }, "devDependencies": { "@metamask/auto-changelog": "^3.4.4", diff --git a/packages/assets-controller/src/AssetsController.test.ts b/packages/assets-controller/src/AssetsController.test.ts index b0ffe5a3fd6..9c84b200590 100644 --- a/packages/assets-controller/src/AssetsController.test.ts +++ b/packages/assets-controller/src/AssetsController.test.ts @@ -58,6 +58,11 @@ function createMockInternalAccount( type WithControllerOptions = { state?: Partial; isBasicFunctionality?: () => boolean; + /** + * When set, registers ClientController:getState so the controller sees this UI state. + * Required for tests that rely on asset tracking running (e.g. trackMetaMetricsEvent on unlock). + */ + clientControllerState?: { isUiOpen: boolean }; /** Extra options passed to AssetsController constructor (e.g. trackMetaMetricsEvent). */ controllerOptions?: Partial<{ trackMetaMetricsEvent: ( @@ -91,6 +96,7 @@ async function withController( { state = {}, isBasicFunctionality = (): boolean => true, + clientControllerState, controllerOptions = {}, }, fn, @@ -144,6 +150,17 @@ async function withController( tokensChainsCache: {}, })); + if (clientControllerState !== undefined) { + ( + messenger as { + registerActionHandler: (a: string, h: () => unknown) => void; + } + ).registerActionHandler( + 'ClientController:getState', + () => clientControllerState, + ); + } + const controller = new AssetsController({ messenger: messenger as unknown as AssetsControllerMessenger, state, @@ -814,8 +831,15 @@ describe('AssetsController', () => { const trackMetaMetricsEvent = jest.fn(); await withController( - { controllerOptions: { trackMetaMetricsEvent } }, + { + clientControllerState: { isUiOpen: true }, + controllerOptions: { trackMetaMetricsEvent }, + }, async ({ messenger }) => { + // UI must be open and keyring unlocked for asset tracking to run + ( + messenger as { publish: (topic: string, payload?: unknown) => void } + ).publish('ClientController:stateChange', { isUiOpen: true }); messenger.publish('KeyringController:unlock'); // Allow #start() -> getAssets() to resolve so the callback runs @@ -842,8 +866,15 @@ describe('AssetsController', () => { const trackMetaMetricsEvent = jest.fn(); await withController( - { controllerOptions: { trackMetaMetricsEvent } }, + { + clientControllerState: { isUiOpen: true }, + controllerOptions: { trackMetaMetricsEvent }, + }, async ({ messenger }) => { + // UI must be open and keyring unlocked for asset tracking to run + ( + messenger as { publish: (topic: string, payload?: unknown) => void } + ).publish('ClientController:stateChange', { isUiOpen: true }); messenger.publish('KeyringController:unlock'); await new Promise((resolve) => setTimeout(resolve, 100)); diff --git a/packages/assets-controller/src/AssetsController.ts b/packages/assets-controller/src/AssetsController.ts index ba421627e37..6cd458096fb 100644 --- a/packages/assets-controller/src/AssetsController.ts +++ b/packages/assets-controller/src/AssetsController.ts @@ -9,6 +9,8 @@ import type { ControllerStateChangeEvent, StateMetadata, } from '@metamask/base-controller'; +import type { ClientControllerStateChangeEvent } from '@metamask/client-controller'; +import { clientControllerSelectors } from '@metamask/client-controller'; import type { ApiPlatformClient, BackendWebSocketServiceActions, @@ -73,9 +75,14 @@ import { StakedBalanceDataSource } from './data-sources/StakedBalanceDataSource' import { TokenDataSource } from './data-sources/TokenDataSource'; import { projectLogger, createModuleLogger } from './logger'; import { DetectionMiddleware } from './middlewares/DetectionMiddleware'; +import { + createParallelBalanceMiddleware, + createParallelMiddleware, +} from './middlewares/ParallelMiddleware'; import type { AccountId, AssetPreferences, + AssetsUpdateMode, ChainId, Caip19AssetId, AssetMetadata, @@ -225,6 +232,7 @@ type AllowedActions = type AllowedEvents = // AssetsController | AccountTreeControllerSelectedAccountGroupChangeEvent + | ClientControllerStateChangeEvent | KeyringControllerLockEvent | KeyringControllerUnlockEvent | PreferencesControllerStateChangeEvent @@ -418,6 +426,10 @@ function normalizeResponse(response: DataResponse): DataResponse { normalized.errors = { ...response.errors }; } + if (response.updateMode) { + normalized.updateMode = response.updateMode; + } + return normalized; } @@ -441,8 +453,10 @@ function normalizeResponse(response: DataResponse): DataResponse { * based on which chains they support. When active chains change, the controller * dynamically adjusts subscriptions. * - * 4. **Keyring Lifecycle**: Listens to KeyringController unlock/lock events to - * start/stop subscriptions when the wallet is unlocked or locked. + * 4. **Client + Keyring Lifecycle**: Starts subscriptions only when both the UI is + * open (ClientController) and the wallet is unlocked (KeyringController). + * Stops when either the UI closes or the keyring locks. See client-controller + * README for the combined pattern. * * ## Architecture * @@ -472,6 +486,12 @@ export class AssetsController extends BaseController< /** Whether we have already reported first init fetch for this session (reset on #stop). */ #firstInitFetchReported = false; + /** Whether the client (UI) is open. Combined with #keyringUnlocked for #updateActive. */ + #uiOpen = false; + + /** Whether the keyring is unlocked. Combined with #uiOpen for #updateActive. */ + #keyringUnlocked = false; + readonly #controllerMutex = new Mutex(); /** @@ -621,7 +641,7 @@ export class AssetsController extends BaseController< this.#initializeState(); this.#subscribeToEvents(); this.#registerActionHandlers(); - // Subscriptions start only on KeyringController:unlock -> #start(), not here. + // Subscriptions start only when both UI is open and keyring unlocked -> #updateActive(). // Subscribe to basic-functionality changes after construction so a synchronous // onChange during subscribe cannot run before data sources are initialized. @@ -722,9 +742,36 @@ export class AssetsController extends BaseController< }, ); - // Keyring lifecycle: start when unlocked, stop when locked - this.messenger.subscribe('KeyringController:unlock', () => this.#start()); - this.messenger.subscribe('KeyringController:lock', () => this.#stop()); + // Client + Keyring lifecycle: only run when UI is open AND keyring is unlocked + this.messenger.subscribe( + 'ClientController:stateChange', + (isUiOpen: boolean) => { + this.#uiOpen = isUiOpen; + this.#updateActive(); + }, + clientControllerSelectors.selectIsUiOpen, + ); + this.messenger.subscribe('KeyringController:unlock', () => { + this.#keyringUnlocked = true; + this.#updateActive(); + }); + this.messenger.subscribe('KeyringController:lock', () => { + this.#keyringUnlocked = false; + this.#updateActive(); + }); + } + + /** + * Start or stop asset tracking based on client (UI) open state and keyring + * unlock state. Only runs when both UI is open and keyring is unlocked. + */ + #updateActive(): void { + const shouldRun = this.#uiOpen && this.#keyringUnlocked; + if (shouldRun) { + this.#start(); + } else { + this.#stop(); + } } #registerActionHandlers(): void { @@ -890,6 +937,10 @@ export class AssetsController extends BaseController< const assetTypes = options?.assetTypes ?? ['fungible']; const dataTypes = options?.dataTypes ?? ['balance', 'metadata', 'price']; + if (accounts.length === 0 || chainIds.length === 0) { + return this.#getAssetsFromState(accounts, chainIds, assetTypes); + } + // Collect custom assets for all requested accounts const customAssets: Caip19AssetId[] = []; for (const account of accounts) { @@ -907,13 +958,17 @@ export class AssetsController extends BaseController< }); const sources = this.#isBasicFunctionality() ? [ - this.#accountsApiDataSource, - this.#snapDataSource, - this.#rpcDataSource, - this.#stakedBalanceDataSource, + createParallelBalanceMiddleware([ + this.#accountsApiDataSource, + this.#snapDataSource, + this.#rpcDataSource, + this.#stakedBalanceDataSource, + ]), this.#detectionMiddleware, - this.#tokenDataSource, - this.#priceDataSource, + createParallelMiddleware([ + this.#tokenDataSource, + this.#priceDataSource, + ]), ] : [ this.#rpcDataSource, @@ -924,7 +979,7 @@ export class AssetsController extends BaseController< sources, request, ); - await this.#updateState(response); + await this.#updateState({ ...response, updateMode: 'full' }); if (this.#trackMetaMetricsEvent && !this.#firstInitFetchReported) { this.#firstInitFetchReported = true; const durationMs = Date.now() - startTime; @@ -1199,8 +1254,8 @@ export class AssetsController extends BaseController< // ============================================================================ async #updateState(response: DataResponse): Promise { - // Normalize asset IDs (checksum EVM addresses) before storing in state const normalizedResponse = normalizeResponse(response); + const mode: AssetsUpdateMode = normalizedResponse.updateMode ?? 'merge'; const releaseLock = await this.#controllerMutex.acquire(); @@ -1248,20 +1303,35 @@ export class AssetsController extends BaseController< )) { const previousBalances = previousState.assetsBalance[accountId] ?? {}; - - if (!balances[accountId]) { - balances[accountId] = {}; - } - - for (const [assetId, balance] of Object.entries(accountBalances)) { + const customAssetIds = + (state.customAssets as Record)[ + accountId + ] ?? []; + + // Full: response is authoritative; preserve custom assets not in response. Merge: response overlays previous. + const effective: Record = + mode === 'full' + ? ((): Record => { + const next: Record = { + ...accountBalances, + }; + for (const customId of customAssetIds) { + if (!(customId in next)) { + const prev = previousBalances[customId]; + next[customId] = + prev ?? ({ amount: '0' } as AssetBalance); + } + } + return next; + })() + : { ...previousBalances, ...accountBalances }; + + for (const [assetId, balance] of Object.entries(effective)) { const previousBalance = previousBalances[ assetId as Caip19AssetId ] as { amount: string } | undefined; - const balanceData = balance as { amount: string }; - const newAmount = balanceData.amount; + const newAmount = (balance as { amount: string }).amount; const oldAmount = previousBalance?.amount; - - // Track if balance actually changed if (oldAmount !== newAmount) { changedBalances.push({ accountId, @@ -1271,8 +1341,7 @@ export class AssetsController extends BaseController< }); } } - - Object.assign(balances[accountId], accountBalances); + balances[accountId] = effective; } } @@ -1537,7 +1606,7 @@ export class AssetsController extends BaseController< * Subscribe to asset updates for all selected accounts. */ #subscribeAssets(): void { - if (this.#selectedAccounts.length === 0) { + if (this.#selectedAccounts.length === 0 || this.#enabledChains.size === 0) { return; } @@ -1909,10 +1978,16 @@ export class AssetsController extends BaseController< hasPrice: Boolean(response.assetsPrice), }); - // Run through enrichment middlewares (Event Stack: Detection → Token → Price) + // Run through enrichment middlewares (Detection, then Token + Price in parallel) // Include 'metadata' in dataTypes so TokenDataSource runs to enrich detected assets const { response: enrichedResponse } = await this.#executeMiddlewares( - [this.#detectionMiddleware, this.#tokenDataSource, this.#priceDataSource], + [ + this.#detectionMiddleware, + createParallelMiddleware([ + this.#tokenDataSource, + this.#priceDataSource, + ]), + ], request ?? { accountsWithSupportedChains: [], chainIds: [], diff --git a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts index a0846e390eb..69a48da275b 100644 --- a/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts +++ b/packages/assets-controller/src/data-sources/AccountsApiDataSource.ts @@ -325,6 +325,7 @@ export class AccountsApiDataSource extends AbstractDataSource< ); response.assetsBalance = assetsBalance; + response.updateMode = 'full'; } catch (error) { log('Fetch FAILED', { error, chains: chainsToFetch }); diff --git a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts index 6d38ac55bb7..9d6c617109d 100644 --- a/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts +++ b/packages/assets-controller/src/data-sources/BackendWebsocketDataSource.ts @@ -624,7 +624,7 @@ export class BackendWebsocketDataSource extends AbstractDataSource< }; } - const response: DataResponse = {}; + const response: DataResponse = { updateMode: 'merge' }; if (Object.keys(assetsBalance[accountId]).length > 0) { response.assetsBalance = assetsBalance; response.assetsInfo = assetsMetadata; diff --git a/packages/assets-controller/src/data-sources/PriceDataSource.ts b/packages/assets-controller/src/data-sources/PriceDataSource.ts index 3d635acf440..2560ed049a4 100644 --- a/packages/assets-controller/src/data-sources/PriceDataSource.ts +++ b/packages/assets-controller/src/data-sources/PriceDataSource.ts @@ -393,7 +393,10 @@ export class PriceDataSource { fetchResponse.assetsPrice && Object.keys(fetchResponse.assetsPrice).length > 0 ) { - await subscription.onAssetsUpdate(fetchResponse); + await subscription.onAssetsUpdate({ + ...fetchResponse, + updateMode: 'merge', + }); } } catch (error) { log('Subscription poll failed', { subscriptionId, error }); diff --git a/packages/assets-controller/src/data-sources/RpcDataSource.ts b/packages/assets-controller/src/data-sources/RpcDataSource.ts index 922f0bb40d8..78816a45c58 100644 --- a/packages/assets-controller/src/data-sources/RpcDataSource.ts +++ b/packages/assets-controller/src/data-sources/RpcDataSource.ts @@ -411,6 +411,7 @@ export class RpcDataSource extends AbstractDataSource< [result.accountId]: newBalances, }, assetsInfo, + updateMode: 'full', }; log('Balance update response', { @@ -483,6 +484,7 @@ export class RpcDataSource extends AbstractDataSource< assetsBalance: { [result.accountId]: newBalances, }, + updateMode: 'full', }; for (const subscription of this.#activeSubscriptions.values()) { diff --git a/packages/assets-controller/src/data-sources/SnapDataSource.test.ts b/packages/assets-controller/src/data-sources/SnapDataSource.test.ts index 561cfd440da..180c8e1ed8b 100644 --- a/packages/assets-controller/src/data-sources/SnapDataSource.test.ts +++ b/packages/assets-controller/src/data-sources/SnapDataSource.test.ts @@ -435,6 +435,7 @@ describe('SnapDataSource', () => { expect(response).toStrictEqual({ assetsBalance: {}, assetsInfo: {}, + updateMode: 'full', }); cleanup(); @@ -469,6 +470,7 @@ describe('SnapDataSource', () => { expect(response).toStrictEqual({ assetsBalance: {}, assetsInfo: {}, + updateMode: 'full', }); cleanup(); diff --git a/packages/assets-controller/src/data-sources/SnapDataSource.ts b/packages/assets-controller/src/data-sources/SnapDataSource.ts index 823daae02bd..b6e7adc8b41 100644 --- a/packages/assets-controller/src/data-sources/SnapDataSource.ts +++ b/packages/assets-controller/src/data-sources/SnapDataSource.ts @@ -298,7 +298,7 @@ export class SnapDataSource extends AbstractDataSource< // Only report if we have snap-related updates if (assetsBalance) { - const response: DataResponse = { assetsBalance }; + const response: DataResponse = { assetsBalance, updateMode: 'merge' }; for (const subscription of this.activeSubscriptions.values()) { subscription.onAssetsUpdate(response)?.catch(console.error); } @@ -439,12 +439,13 @@ export class SnapDataSource extends AbstractDataSource< return {}; } if (!request?.accountsWithSupportedChains?.length) { - return { assetsBalance: {}, assetsInfo: {} }; + return { assetsBalance: {}, assetsInfo: {}, updateMode: 'full' }; } const results: DataResponse = { assetsBalance: {}, assetsInfo: {}, + updateMode: 'full', }; // Fetch balances for each account using its snap ID from metadata diff --git a/packages/assets-controller/src/index.ts b/packages/assets-controller/src/index.ts index be2857c1b4c..ac43554d6e9 100644 --- a/packages/assets-controller/src/index.ts +++ b/packages/assets-controller/src/index.ts @@ -66,6 +66,7 @@ export type { DataType, DataRequest, DataResponse, + AssetsUpdateMode, // Middleware types Context, NextFunction, diff --git a/packages/assets-controller/src/middlewares/ParallelMiddleware.test.ts b/packages/assets-controller/src/middlewares/ParallelMiddleware.test.ts new file mode 100644 index 00000000000..c6e5811065f --- /dev/null +++ b/packages/assets-controller/src/middlewares/ParallelMiddleware.test.ts @@ -0,0 +1,126 @@ +import { createParallelMiddleware } from './ParallelMiddleware'; +import type { TokenPriceSource } from './ParallelMiddleware'; +import type { Context, DataResponse } from '../types'; + +const MOCK_ASSET = 'eip155:1/erc20:0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + +function createMockContext(overrides?: Partial): Context { + return { + request: { + chainIds: ['eip155:1'], + accountsWithSupportedChains: [], + dataTypes: ['balance', 'metadata', 'price'], + }, + response: {}, + getAssetsState: jest.fn().mockReturnValue({ + assetsInfo: {}, + assetsBalance: {}, + customAssets: {}, + }), + ...overrides, + }; +} + +function createMockSource( + name: string, + response: DataResponse, +): TokenPriceSource { + return { + getName: () => name, + assetsMiddleware: async (ctx, next): Promise => { + return next({ + ...ctx, + response: { ...ctx.response, ...response }, + }); + }, + }; +} + +describe('createParallelMiddleware', () => { + describe('getName', () => { + it('returns ParallelMiddleware', () => { + const middleware = createParallelMiddleware([]); + expect(middleware.getName()).toBe('ParallelMiddleware'); + }); + }); + + describe('assetsMiddleware', () => { + it('calls next with same context when sources array is empty', async () => { + const middleware = createParallelMiddleware([]); + const context = createMockContext(); + const next = jest.fn().mockResolvedValue(context); + + await middleware.assetsMiddleware(context, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(next).toHaveBeenCalledWith(context); + }); + + it('runs multiple sources in parallel and merges responses', async () => { + const tokenSource = createMockSource('TokenSource', { + assetsInfo: { + [MOCK_ASSET]: { + type: 'erc20', + symbol: 'USDC', + name: 'USD Coin', + decimals: 6, + }, + }, + }); + const priceSource = createMockSource('PriceSource', { + assetsPrice: { + [MOCK_ASSET]: { + price: 1.0, + lastUpdated: Date.now(), + }, + }, + }); + + const middleware = createParallelMiddleware([tokenSource, priceSource]); + const context = createMockContext(); + const next = jest.fn().mockImplementation((ctx: Context) => { + return Promise.resolve(ctx); + }); + + const result = await middleware.assetsMiddleware(context, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(result.response.assetsInfo).toHaveProperty(MOCK_ASSET); + expect(result.response.assetsInfo?.[MOCK_ASSET]).toMatchObject({ + symbol: 'USDC', + name: 'USD Coin', + decimals: 6, + }); + expect(result.response.assetsPrice).toHaveProperty(MOCK_ASSET); + expect(result.response.assetsPrice?.[MOCK_ASSET]).toMatchObject({ + price: 1.0, + }); + }); + + it('merges with existing context.response', async () => { + const source = createMockSource('Single', { + assetsInfo: { + [MOCK_ASSET]: { type: 'erc20', symbol: 'T', name: 'T', decimals: 18 }, + }, + }); + const middleware = createParallelMiddleware([source]); + const context = createMockContext({ + response: { + assetsBalance: { + 'account-1': { [MOCK_ASSET]: { balance: '100' as `${number}` } }, + }, + }, + }); + const next = jest + .fn() + .mockImplementation((ctx: Context) => Promise.resolve(ctx)); + + const result = await middleware.assetsMiddleware(context, next); + + expect(result.response.assetsBalance).toStrictEqual( + context.response.assetsBalance, + ); + expect(result.response.assetsInfo).toHaveProperty(MOCK_ASSET); + }); + }); +}); diff --git a/packages/assets-controller/src/middlewares/ParallelMiddleware.ts b/packages/assets-controller/src/middlewares/ParallelMiddleware.ts new file mode 100644 index 00000000000..e3615530378 --- /dev/null +++ b/packages/assets-controller/src/middlewares/ParallelMiddleware.ts @@ -0,0 +1,309 @@ +import pLimit from 'p-limit'; + +import type { + ChainId, + Context, + DataRequest, + DataResponse, + Middleware, +} from '../types'; + +// ============================================================================ +// MERGE HELPER +// ============================================================================ + +/** + * Deep-merge multiple DataResponses into one. + * Used when running balance data sources in parallel. + * + * @param responses - Array of DataResponse from each source. + * @returns Single merged DataResponse. + */ +export function mergeDataResponses(responses: DataResponse[]): DataResponse { + const merged: DataResponse = {}; + + for (const response of responses) { + if (response.assetsBalance) { + merged.assetsBalance ??= {}; + for (const [accountId, accountBalances] of Object.entries( + response.assetsBalance, + )) { + merged.assetsBalance[accountId] = { + ...(merged.assetsBalance[accountId] ?? {}), + ...accountBalances, + }; + } + } + if (response.assetsInfo) { + merged.assetsInfo = { + ...(merged.assetsInfo ?? {}), + ...response.assetsInfo, + }; + } + if (response.assetsPrice) { + merged.assetsPrice = { + ...(merged.assetsPrice ?? {}), + ...response.assetsPrice, + }; + } + if (response.errors) { + merged.errors = { + ...(merged.errors ?? {}), + ...response.errors, + }; + } + if (response.detectedAssets) { + merged.detectedAssets = { + ...(merged.detectedAssets ?? {}), + ...response.detectedAssets, + }; + } + if (response.updateMode === 'full') { + merged.updateMode = 'full'; + } + } + merged.updateMode ??= 'merge'; + + return merged; +} + +// ============================================================================ +// PARALLEL BALANCE MIDDLEWARE +// ============================================================================ + +const PARALLEL_BALANCE_MIDDLEWARE_NAME = 'ParallelBalanceMiddleware'; + +/** Max concurrent balance source calls (round 1 and fallback). */ +const BALANCE_CONCURRENCY = 3; + +export type BalanceSource = { + getName(): string; + /** Chains this source can fetch (e.g. from getActiveChainsSync()). Used to partition chains with no overlap. */ + getActiveChainsSync(): ChainId[]; + assetsMiddleware: Middleware; +}; + +/** + * Partition request.chainIds so each chain is assigned to exactly one source + * (by source order: first source that supports the chain gets it). Ensures no + * chain overlap across data source calls. + * + * @param request - The data request with chainIds to partition. + * @param sources - Balance sources in priority order (e.g. AccountsAPI, Snap, Rpc). + * @returns Array of requests, one per source, each with only that source's assigned chainIds. + */ +function partitionChainsBySource( + request: DataRequest, + sources: BalanceSource[], +): DataRequest[] { + const { chainIds } = request; + const assigned = new Set(); + + return sources.map((source) => { + const supported = new Set(source.getActiveChainsSync()); + const chainsForSource = chainIds.filter( + (id) => supported.has(id) && !assigned.has(id), + ); + chainsForSource.forEach((id) => assigned.add(id)); + + return { + ...request, + chainIds: chainsForSource, + }; + }); +} + +/** + * Collect chain IDs that failed in the first round (present in response.errors). + * Used to run a fallback round with remaining sources. + * + * @param requests - Partitioned requests, one per source (same order as results). + * @param results - Results from each source; chain IDs in requests[i] that have errors in results[i].response.errors are considered failed. + * @returns Set of chain IDs that had errors in the first round. + */ +function getFailedChainIds( + requests: DataRequest[], + results: { response: DataResponse }[], +): Set { + const failed = new Set(); + for (let i = 0; i < results.length; i++) { + const errors = results[i].response.errors ?? {}; + for (const chainId of requests[i].chainIds) { + if (errors[chainId]) { + failed.add(chainId); + } + } + } + return failed; +} + +/** + * Middleware that runs multiple balance data source middlewares in parallel, + * with no chain overlap. Chains that fail (response.errors) are re-partitioned + * and fetched again in a fallback round so lower-priority sources can try them. + * + * @param sources - Array of balance sources in priority order (each with getName(), getActiveChainsSync(), assetsMiddleware). + * @returns A single middleware that runs all sources in parallel and merges responses. + */ +export function createParallelBalanceMiddleware(sources: BalanceSource[]): { + getName(): string; + assetsMiddleware: Middleware; +} { + return { + getName(): string { + return PARALLEL_BALANCE_MIDDLEWARE_NAME; + }, + + assetsMiddleware: async (context, next): Promise => { + if (sources.length === 0) { + return next(context); + } + + const noopNext = async (ctx: typeof context): Promise => + ctx; + const limit = pLimit(BALANCE_CONCURRENCY); + + // Round 1: partition chains (no overlap), run with limited concurrency + const requests = partitionChainsBySource(context.request, sources); + const results = await Promise.all( + sources.map((source, i) => + limit(() => + source.assetsMiddleware( + { + request: requests[i], + response: {}, + getAssetsState: context.getAssetsState, + }, + noopNext, + ), + ), + ), + ); + + let mergedResponse = mergeDataResponses( + results.map((result) => result.response), + ); + + // Fallback: chains that failed (in errors) get re-partitioned and tried again + const failedChainIds = getFailedChainIds(requests, results); + if (failedChainIds.size > 0) { + const fallbackRequest: DataRequest = { + ...context.request, + chainIds: [...failedChainIds], + }; + const fallbackRequests = partitionChainsBySource( + fallbackRequest, + sources, + ); + const fallbackResults = await Promise.all( + sources.map((source, i) => + limit(() => + source.assetsMiddleware( + { + request: fallbackRequests[i], + response: {}, + getAssetsState: context.getAssetsState, + }, + noopNext, + ), + ), + ), + ); + const fallbackMerged = mergeDataResponses( + fallbackResults.map((result) => result.response), + ); + mergedResponse = mergeDataResponses([mergedResponse, fallbackMerged]); + // Remove errors for chains we successfully got balance for in fallback + if (mergedResponse.errors && mergedResponse.assetsBalance) { + const chainsWithBalance = new Set(); + for (const accountBalances of Object.values( + mergedResponse.assetsBalance, + )) { + for (const assetId of Object.keys(accountBalances)) { + const chainId = assetId.split('/')[0] as ChainId; + chainsWithBalance.add(chainId); + } + } + for (const chainId of failedChainIds) { + if (chainsWithBalance.has(chainId)) { + delete mergedResponse.errors[chainId]; + } + } + } + } + + return next({ + ...context, + response: mergeDataResponses([context.response, mergedResponse]), + }); + }, + }; +} + +// ============================================================================ +// PARALLEL TOKEN/PRICE MIDDLEWARE +// ============================================================================ + +const PARALLEL_MIDDLEWARE_NAME = 'ParallelMiddleware'; + +/** Max concurrent token/price source calls. */ +const CONCURRENCY = 2; + +export type TokenPriceSource = { + getName(): string; + assetsMiddleware: Middleware; +}; + +/** + * Middleware that runs multiple data source middlewares (e.g. TokenDataSource, + * PriceDataSource) in parallel with the same request. Responses are merged so + * that assetsInfo (token metadata) and assetsPrice are combined. Use this to + * fetch token and price data concurrently instead of sequentially. + * + * @param sources - Array of sources with getName() and assetsMiddleware. + * @returns A single middleware that runs all sources in parallel and merges responses. + */ +export function createParallelMiddleware(sources: TokenPriceSource[]): { + getName(): string; + assetsMiddleware: Middleware; +} { + return { + getName(): string { + return PARALLEL_MIDDLEWARE_NAME; + }, + + assetsMiddleware: async (context, next): Promise => { + if (sources.length === 0) { + return next(context); + } + + const noopNext = async (ctx: typeof context): Promise => + ctx; + const limit = pLimit(CONCURRENCY); + + const results = await Promise.all( + sources.map((source) => + limit(() => + source.assetsMiddleware( + { + request: context.request, + response: { ...context.response }, + getAssetsState: context.getAssetsState, + }, + noopNext, + ), + ), + ), + ); + + const mergedResponse = mergeDataResponses( + results.map((result) => result.response), + ); + + return next({ + ...context, + response: mergeDataResponses([context.response, mergedResponse]), + }); + }, + }; +} diff --git a/packages/assets-controller/src/middlewares/index.ts b/packages/assets-controller/src/middlewares/index.ts index 2c54fa8b313..33346a734db 100644 --- a/packages/assets-controller/src/middlewares/index.ts +++ b/packages/assets-controller/src/middlewares/index.ts @@ -1 +1,7 @@ export { DetectionMiddleware } from './DetectionMiddleware'; +export { + createParallelBalanceMiddleware, + createParallelMiddleware, + mergeDataResponses, +} from './ParallelMiddleware'; +export type { BalanceSource, TokenPriceSource } from './ParallelMiddleware'; diff --git a/packages/assets-controller/src/types.ts b/packages/assets-controller/src/types.ts index 5365db54c1c..8a63f91e1e8 100644 --- a/packages/assets-controller/src/types.ts +++ b/packages/assets-controller/src/types.ts @@ -360,8 +360,23 @@ export type DataResponse = { errors?: Record; /** Detected assets (assets that do not have metadata) */ detectedAssets?: Record; + /** + * How to apply this response to state. See {@link AssetsUpdateMode}. + * Defaults to `'merge'` if omitted. + */ + updateMode?: AssetsUpdateMode; }; +/** + * Type of {@link DataResponse.updateMode}: how the controller applies the response to state. + * + * - **full**: Response is the full set for the scope. Assets in state but not in the + * response are cleared (except custom assets). Use for initial fetch or full refresh. + * - **merge**: Only assets present in the response are updated; nothing is removed. + * Use for event-driven or incremental updates. + */ +export type AssetsUpdateMode = 'full' | 'merge'; + // ============================================================================ // DATA SOURCE <-> CONTROLLER (DIRECT CALLS, NO MESSENGER PER SOURCE) // ============================================================================ diff --git a/packages/assets-controller/tsconfig.build.json b/packages/assets-controller/tsconfig.build.json index 3dd2cab2e64..e826cd3f491 100644 --- a/packages/assets-controller/tsconfig.build.json +++ b/packages/assets-controller/tsconfig.build.json @@ -7,14 +7,15 @@ }, "references": [ { "path": "../account-tree-controller/tsconfig.build.json" }, + { "path": "../assets-controllers/tsconfig.build.json" }, { "path": "../base-controller/tsconfig.build.json" }, + { "path": "../client-controller/tsconfig.build.json" }, { "path": "../core-backend/tsconfig.build.json" }, { "path": "../keyring-controller/tsconfig.build.json" }, { "path": "../messenger/tsconfig.build.json" }, { "path": "../network-enablement-controller/tsconfig.build.json" }, { "path": "../permission-controller/tsconfig.build.json" }, - { "path": "../preferences-controller/tsconfig.build.json" }, - { "path": "../assets-controllers/tsconfig.build.json" } + { "path": "../preferences-controller/tsconfig.build.json" } ], "include": ["../../types", "./src"], "exclude": ["**/*.test.ts", "**/__fixtures__/"] diff --git a/packages/assets-controller/tsconfig.json b/packages/assets-controller/tsconfig.json index 5fa36386931..0b58b6c2916 100644 --- a/packages/assets-controller/tsconfig.json +++ b/packages/assets-controller/tsconfig.json @@ -5,13 +5,14 @@ }, "references": [ { "path": "../account-tree-controller" }, + { "path": "../assets-controllers" }, { "path": "../base-controller" }, + { "path": "../client-controller" }, { "path": "../core-backend" }, { "path": "../keyring-controller" }, { "path": "../messenger" }, { "path": "../network-enablement-controller" }, - { "path": "../preferences-controller" }, - { "path": "../assets-controllers" } + { "path": "../preferences-controller" } ], "include": ["../../types", "./src"] } diff --git a/yarn.lock b/yarn.lock index 2089980d930..9bc1d2007ca 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2519,6 +2519,7 @@ __metadata: "@metamask/assets-controllers": "npm:^100.0.2" "@metamask/auto-changelog": "npm:^3.4.4" "@metamask/base-controller": "npm:^9.0.0" + "@metamask/client-controller": "npm:^1.0.0" "@metamask/controller-utils": "npm:^11.19.0" "@metamask/core-backend": "npm:^6.0.0" "@metamask/keyring-api": "npm:^21.5.0" @@ -2543,6 +2544,7 @@ __metadata: deepmerge: "npm:^4.2.2" jest: "npm:^29.7.0" lodash: "npm:^4.17.21" + p-limit: "npm:^3.1.0" ts-jest: "npm:^29.2.5" tsx: "npm:^4.20.5" typedoc: "npm:^0.25.13" @@ -2861,7 +2863,7 @@ __metadata: languageName: unknown linkType: soft -"@metamask/client-controller@workspace:packages/client-controller": +"@metamask/client-controller@npm:^1.0.0, @metamask/client-controller@workspace:packages/client-controller": version: 0.0.0-use.local resolution: "@metamask/client-controller@workspace:packages/client-controller" dependencies: