diff --git a/packages/snaps-controllers/src/multichain/MultichainRoutingController.test.ts b/packages/snaps-controllers/src/multichain/MultichainRoutingController.test.ts index 77fd54e787..1950f52b8b 100644 --- a/packages/snaps-controllers/src/multichain/MultichainRoutingController.test.ts +++ b/packages/snaps-controllers/src/multichain/MultichainRoutingController.test.ts @@ -40,8 +40,7 @@ describe('MultichainRoutingController', () => { rootMessenger.registerActionHandler( 'SnapController:handleRequest', async ({ handler }) => { - // TODO: Use proper handler - if (handler === HandlerType.OnProtocolRequest) { + if (handler === HandlerType.OnKeyringRequest) { return null; } throw new Error('Unmocked request'); @@ -97,8 +96,7 @@ describe('MultichainRoutingController', () => { rootMessenger.registerActionHandler( 'SnapController:handleRequest', async ({ handler }) => { - // TODO: Use proper handler - if (handler === HandlerType.OnProtocolRequest) { + if (handler === HandlerType.OnKeyringRequest) { return { address: SOLANA_CONNECTED_ACCOUNTS[0] }; } throw new Error('Unmocked request'); diff --git a/packages/snaps-controllers/src/multichain/MultichainRoutingController.ts b/packages/snaps-controllers/src/multichain/MultichainRoutingController.ts index 0e3205a1be..c759db8deb 100644 --- a/packages/snaps-controllers/src/multichain/MultichainRoutingController.ts +++ b/packages/snaps-controllers/src/multichain/MultichainRoutingController.ts @@ -133,19 +133,20 @@ export class MultichainRoutingController extends BaseController< request: JsonRpcRequest, ) { try { + // TODO: Decide if we should call this using another abstraction. const result = (await this.messagingSystem.call( 'SnapController:handleRequest', { snapId, origin: 'metamask', request: { - method: '', + method: 'keyring_resolveAccountAddress', params: { scope, request, }, }, - handler: HandlerType.OnProtocolRequest, // TODO: Export and request format + handler: HandlerType.OnKeyringRequest, }, )) as { address: CaipAccountId } | null; const address = result?.address; @@ -233,6 +234,7 @@ export class MultichainRoutingController extends BaseController< async handleRequest({ connectedAddresses, + origin, scope, request, }: { @@ -272,10 +274,13 @@ export class MultichainRoutingController extends BaseController< if (protocolSnap) { return this.messagingSystem.call('SnapController:handleRequest', { snapId: protocolSnap.snapId, - origin: 'metamask', // TODO: Determine origin of these requests? + origin: 'metamask', request: { method: '', params: { + // We are overriding the origin here, so that the Snap gets the proper origin + // while the permissions check is skipped due to the requesting origin being metamask. + origin, request, scope, }, diff --git a/packages/snaps-execution-environments/src/common/commands.ts b/packages/snaps-execution-environments/src/common/commands.ts index 222e5caa13..0a5a965f0e 100644 --- a/packages/snaps-execution-environments/src/common/commands.ts +++ b/packages/snaps-execution-environments/src/common/commands.ts @@ -15,6 +15,7 @@ import { assertIsOnSignatureRequestArguments, assertIsOnNameLookupRequestArguments, assertIsOnUserInputRequestArguments, + assertIsOnProtocolRequestArguments, } from './validation'; export type CommandMethodsMapping = { @@ -74,9 +75,21 @@ export function getHandlerArguments( address, }; } + + case HandlerType.OnProtocolRequest: { + assertIsOnProtocolRequestArguments(request.params); + + // For this specific handler we extract the origin from the parameters. + const { + origin: nestedOrigin, + request: nestedRequest, + scope, + } = request.params; + return { origin: nestedOrigin, request: nestedRequest, scope }; + } + case HandlerType.OnRpcRequest: case HandlerType.OnKeyringRequest: - case HandlerType.OnProtocolRequest: // TODO: Decide on origin return { origin, request }; case HandlerType.OnCronjob: diff --git a/packages/snaps-execution-environments/src/common/validation.ts b/packages/snaps-execution-environments/src/common/validation.ts index 166c9df8c7..0fabe6d4cb 100644 --- a/packages/snaps-execution-environments/src/common/validation.ts +++ b/packages/snaps-execution-environments/src/common/validation.ts @@ -25,6 +25,7 @@ import { assertStruct, JsonRpcIdStruct, JsonRpcParamsStruct, + JsonRpcRequestStruct, JsonRpcSuccessStruct, JsonRpcVersionStruct, JsonStruct, @@ -243,6 +244,35 @@ export function assertIsOnUserInputRequestArguments( ); } +export const OnProtocolRequestArgumentsStruct = object({ + origin: string(), + scope: ChainIdStruct, + request: JsonRpcRequestStruct, +}); + +export type OnProtocolRequestArguments = Infer< + typeof OnProtocolRequestArgumentsStruct +>; + +/** + * Asserts that the given value is a valid {@link OnProtocolRequestArguments} + * object. + * + * @param value - The value to validate. + * @throws If the value is not a valid {@link OnProtocolRequestArguments} + * object. + */ +export function assertIsOnProtocolRequestArguments( + value: unknown, +): asserts value is OnProtocolRequestArguments { + assertStruct( + value, + OnProtocolRequestArgumentsStruct, + 'Invalid request params', + rpcErrors.invalidParams, + ); +} + const OkResponseStruct = object({ id: JsonRpcIdStruct, jsonrpc: JsonRpcVersionStruct,