Skip to content

Commit

Permalink
Add Allowlist sample plugin, refactor test base
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Jun 28, 2024
1 parent 68a5847 commit 78ae1eb
Show file tree
Hide file tree
Showing 11 changed files with 690 additions and 87 deletions.
13 changes: 8 additions & 5 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,14 @@ contract UpgradeableModularAccount is
/// with user install configs.
/// @dev This function is only callable once, and only by the EntryPoint.

function initializeDefaultValidation(FunctionReference validationFunction, bytes calldata installData)
external
initializer
{
_installValidation(validationFunction, true, new bytes4[](0), installData, bytes(""));
function initializeWithValidation(
FunctionReference validationFunction,
bool shared,
bytes4[] memory selectors,
bytes calldata installData,
bytes calldata preValidationHooks
) external initializer {
_installValidation(validationFunction, shared, selectors, installData, preValidationHooks);
emit ModularAccountInitialized(_ENTRY_POINT);
}

Expand Down
142 changes: 142 additions & 0 deletions src/samples/permissionhooks/AllowlistPlugin.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";

import {PluginMetadata, PluginManifest} from "../../interfaces/IPlugin.sol";
import {IValidationHook} from "../../interfaces/IValidationHook.sol";
import {IStandardExecutor, Call} from "../../interfaces/IStandardExecutor.sol";
import {BasePlugin} from "../../plugins/BasePlugin.sol";

contract AllowlistPlugin is IValidationHook, BasePlugin {
enum FunctionId {
PRE_VALIDATION_HOOK
}

struct AllowlistInit {
address target;
bool hasSelectorAllowlist;
bytes4[] selectors;
}

struct AllowlistEntry {
bool allowed;
bool hasSelectorAllowlist;
}

mapping(address target => mapping(address account => AllowlistEntry)) public targetAllowlist;
mapping(address target => mapping(bytes4 selector => mapping(address account => bool))) public
selectorAllowlist;

error TargetNotAllowed();
error SelectorNotAllowed();
error NoSelectorSpecified();

function onInstall(bytes calldata data) external override {
AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[]));

for (uint256 i = 0; i < init.length; i++) {
targetAllowlist[init[i].target][msg.sender] = AllowlistEntry(true, init[i].hasSelectorAllowlist);

if (init[i].hasSelectorAllowlist) {
for (uint256 j = 0; j < init[i].selectors.length; j++) {
selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender] = true;
}
}
}
}

function onUninstall(bytes calldata data) external override {
AllowlistInit[] memory init = abi.decode(data, (AllowlistInit[]));

for (uint256 i = 0; i < init.length; i++) {
delete targetAllowlist[init[i].target][msg.sender];

if (init[i].hasSelectorAllowlist) {
for (uint256 j = 0; j < init[i].selectors.length; j++) {
delete selectorAllowlist[init[i].target][init[i].selectors[j]][msg.sender];
}
}
}
}

function setAllowlistTarget(address target, bool allowed, bool hasSelectorAllowlist) external {
targetAllowlist[target][msg.sender] = AllowlistEntry(allowed, hasSelectorAllowlist);
}

function setAllowlistSelector(address target, bytes4 selector, bool allowed) external {
selectorAllowlist[target][selector][msg.sender] = allowed;
}

function preUserOpValidationHook(uint8 functionId, PackedUserOperation calldata userOp, bytes32)
external
view
override
returns (uint256)
{
if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) {
_checkAllowlistCalldata(userOp.callData);
return 0;
}
revert NotImplemented();
}

function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata data, bytes calldata)
external
view
override
{
if (functionId == uint8(FunctionId.PRE_VALIDATION_HOOK)) {
_checkAllowlistCalldata(data);
return;
}

revert NotImplemented();
}

function pluginMetadata() external pure override returns (PluginMetadata memory) {
PluginMetadata memory metadata;
metadata.name = "Allowlist Plugin";
metadata.version = "v0.0.1";
metadata.author = "ERC-6900 Working Group";

return metadata;
}

// solhint-disable-next-line no-empty-blocks
function pluginManifest() external pure override returns (PluginManifest memory) {}

function _checkAllowlistCalldata(bytes calldata callData) internal view {
if (bytes4(callData[:4]) == IStandardExecutor.execute.selector) {
(address target,, bytes memory data) = abi.decode(callData[4:], (address, uint256, bytes));
_checkCallPermission(msg.sender, target, data);
} else if (bytes4(callData[:4]) == IStandardExecutor.executeBatch.selector) {
Call[] memory calls = abi.decode(callData[4:], (Call[]));

for (uint256 i = 0; i < calls.length; i++) {
_checkCallPermission(msg.sender, calls[i].target, calls[i].data);
}
}
}

function _checkCallPermission(address account, address target, bytes memory data) internal view {
AllowlistEntry storage entry = targetAllowlist[target][account];
(bool allowed, bool hasSelectorAllowlist) = (entry.allowed, entry.hasSelectorAllowlist);

if (!allowed) {
revert TargetNotAllowed();
}

if (hasSelectorAllowlist) {
if (data.length < 4) {
revert NoSelectorSpecified();
}

bytes4 selector = bytes4(data);

if (!selectorAllowlist[target][selector][account]) {
revert SelectorNotAllowed();
}
}
}
}
13 changes: 3 additions & 10 deletions test/account/AccountLoupe.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/Functio
import {ExecutionHook} from "../../src/interfaces/IAccountLoupe.sol";
import {IPluginManager} from "../../src/interfaces/IPluginManager.sol";
import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol";
import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol";

import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol";
import {AccountTestBase} from "../utils/AccountTestBase.sol";

contract AccountLoupeTest is AccountTestBase {
ComprehensivePlugin public comprehensivePlugin;

FunctionReference public ownerValidation;

event ReceivedCall(bytes msgData, uint256 msgValue);

function setUp() public {
Expand All @@ -28,10 +25,6 @@ contract AccountLoupeTest is AccountTestBase {
vm.prank(address(entryPoint));
account1.installPlugin(address(comprehensivePlugin), manifestHash, "", new FunctionReference[](0));

ownerValidation = FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
);

FunctionReference[] memory preValidationHooks = new FunctionReference[](2);
preValidationHooks[0] = FunctionReferenceLib.pack(
address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_VALIDATION_HOOK_1)
Expand All @@ -43,7 +36,7 @@ contract AccountLoupeTest is AccountTestBase {
bytes[] memory installDatas = new bytes[](2);
vm.prank(address(entryPoint));
account1.installValidation(
ownerValidation, true, new bytes4[](0), bytes(""), abi.encode(preValidationHooks, installDatas)
_ownerValidation, true, new bytes4[](0), bytes(""), abi.encode(preValidationHooks, installDatas)
);
}

Expand Down Expand Up @@ -106,7 +99,7 @@ contract AccountLoupeTest is AccountTestBase {
validations = account1.getValidations(account1.execute.selector);

assertEq(validations.length, 1);
assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(ownerValidation));
assertEq(FunctionReference.unwrap(validations[0]), FunctionReference.unwrap(_ownerValidation));
}

function test_pluginLoupe_getExecutionHooks() public {
Expand Down Expand Up @@ -147,7 +140,7 @@ contract AccountLoupeTest is AccountTestBase {
}

function test_pluginLoupe_getValidationHooks() public {
FunctionReference[] memory hooks = account1.getPreValidationHooks(ownerValidation);
FunctionReference[] memory hooks = account1.getPreValidationHooks(_ownerValidation);

assertEq(hooks.length, 2);
assertEq(
Expand Down
15 changes: 2 additions & 13 deletions test/account/DefaultValidationTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interface
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";
import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";
import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol";

import {AccountTestBase} from "../utils/AccountTestBase.sol";
import {DefaultValidationFactoryFixture} from "../mocks/DefaultValidationFactoryFixture.sol";
Expand All @@ -16,11 +14,6 @@ contract DefaultValidationTest is AccountTestBase {

DefaultValidationFactoryFixture public defaultValidationFactoryFixture;

uint256 public constant CALL_GAS_LIMIT = 50000;
uint256 public constant VERIFICATION_GAS_LIMIT = 1200000;

FunctionReference public ownerValidation;

address public ethRecipient;

function setUp() public {
Expand All @@ -32,10 +25,6 @@ contract DefaultValidationTest is AccountTestBase {

ethRecipient = makeAddr("ethRecipient");
vm.deal(ethRecipient, 1 wei);

ownerValidation = FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
);
}

function test_defaultValidation_userOp_simple() public {
Expand All @@ -57,7 +46,7 @@ contract DefaultValidationTest is AccountTestBase {
// Generate signature
bytes32 userOpHash = entryPoint.getUserOpHash(userOp);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());
userOp.signature = _encodeSignature(ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v));
userOp.signature = _encodeSignature(_ownerValidation, DEFAULT_VALIDATION, abi.encodePacked(r, s, v));

PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
userOps[0] = userOp;
Expand All @@ -74,7 +63,7 @@ contract DefaultValidationTest is AccountTestBase {
vm.prank(owner1);
account1.executeWithAuthorization(
abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")),
_encodeSignature(ownerValidation, DEFAULT_VALIDATION, "")
_encodeSignature(_ownerValidation, DEFAULT_VALIDATION, "")
);

assertEq(ethRecipient.balance, 2 wei);
Expand Down
3 changes: 0 additions & 3 deletions test/account/MultiValidation.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ contract MultiValidationTest is AccountTestBase {
address public owner2;
uint256 public owner2Key;

uint256 public constant CALL_GAS_LIMIT = 50000;
uint256 public constant VERIFICATION_GAS_LIMIT = 1200000;

function setUp() public {
validator2 = new SingleOwnerPlugin();

Expand Down
65 changes: 27 additions & 38 deletions test/account/PerHookData.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,28 @@ pragma solidity ^0.8.25;

import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol";
import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol";
import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";

import {MockAccessControlHookPlugin} from "../mocks/plugins/MockAccessControlHookPlugin.sol";
import {Counter} from "../mocks/Counter.sol";
import {AccountTestBase} from "../utils/AccountTestBase.sol";
import {CustomValidationTestBase} from "../utils/CustomValidationTestBase.sol";

contract PerHookDataTest is AccountTestBase {
contract PerHookDataTest is CustomValidationTestBase {
using MessageHashUtils for bytes32;

MockAccessControlHookPlugin internal _accessControlHookPlugin;

Counter internal _counter;

FunctionReference internal _ownerValidation;

uint256 public constant CALL_GAS_LIMIT = 50000;
uint256 public constant VERIFICATION_GAS_LIMIT = 1200000;

function setUp() public {
_counter = new Counter();

_accessControlHookPlugin = new MockAccessControlHookPlugin();

// Write over `account1` with a new account proxy, with different initialization.

address accountImplementation = address(factory.accountImplementation());

account1 = UpgradeableModularAccount(payable(new ERC1967Proxy(accountImplementation, "")));

_ownerValidation = FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
);

FunctionReference accessControlHook = FunctionReferenceLib.pack(
address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK)
);

FunctionReference[] memory preValidationHooks = new FunctionReference[](1);
preValidationHooks[0] = accessControlHook;

bytes[] memory preValidationHookData = new bytes[](1);
// Access control is restricted to only the _counter
preValidationHookData[0] = abi.encode(_counter);

bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData);

vm.prank(address(entryPoint));
account1.installValidation(
_ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks
);

vm.deal(address(account1), 100 ether);
_customValidationSetup();
}

function test_passAccessControl_userOp() public {
Expand Down Expand Up @@ -358,4 +323,28 @@ contract PerHookDataTest is AccountTestBase {

return (userOp, userOpHash);
}

// Test config

function _initialValidationConfig()
internal
virtual
override
returns (FunctionReference, bool, bytes4[] memory, bytes memory, bytes memory)
{
FunctionReference accessControlHook = FunctionReferenceLib.pack(
address(_accessControlHookPlugin), uint8(MockAccessControlHookPlugin.FunctionId.PRE_VALIDATION_HOOK)
);

FunctionReference[] memory preValidationHooks = new FunctionReference[](1);
preValidationHooks[0] = accessControlHook;

bytes[] memory preValidationHookData = new bytes[](1);
// Access control is restricted to only the counter
preValidationHookData[0] = abi.encode(_counter);

bytes memory packedPreValidationHooks = abi.encode(preValidationHooks, preValidationHookData);

return (_ownerValidation, true, new bytes4[](0), abi.encode(owner1), packedPreValidationHooks);
}
}
Loading

0 comments on commit 78ae1eb

Please sign in to comment.