Skip to content

Commit

Permalink
per validation hook data
Browse files Browse the repository at this point in the history
  • Loading branch information
adamegyed committed Jun 20, 2024
1 parent 400e833 commit 60916fa
Show file tree
Hide file tree
Showing 17 changed files with 714 additions and 100 deletions.
3 changes: 1 addition & 2 deletions src/account/AccountLoupe.sol
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ abstract contract AccountLoupe is IAccountLoupe {
override
returns (FunctionReference[] memory preValidationHooks)
{
preValidationHooks =
toFunctionReferenceArray(getAccountStorage().validationData[validationFunction].preValidationHooks);
preValidationHooks = getAccountStorage().validationData[validationFunction].preValidationHooks;
}

/// @inheritdoc IAccountLoupe
Expand Down
2 changes: 1 addition & 1 deletion src/account/AccountStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct ValidationData {
// Whether or not this validation is a signature validator.
bool isSignatureValidation;
// The pre validation hooks for this function selector.
EnumerableSet.Bytes32Set preValidationHooks;
FunctionReference[] preValidationHooks;
}

struct AccountStorage {
Expand Down
31 changes: 18 additions & 13 deletions src/account/PluginManager2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet
import {IPlugin} from "../interfaces/IPlugin.sol";
import {FunctionReference} from "../interfaces/IPluginManager.sol";
import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
import {AccountStorage, getAccountStorage, toSetValue, toFunctionReference} from "./AccountStorage.sol";
import {AccountStorage, getAccountStorage, toSetValue} from "./AccountStorage.sol";

// Temporary additional functions for a user-controlled install flow for validation functions.
abstract contract PluginManager2 {
using EnumerableSet for EnumerableSet.Bytes32Set;

uint8 internal constant _RESERVED_VALIDATION_DATA_INDEX = 255;

error DefaultValidationAlreadySet(FunctionReference validationFunction);
error PreValidationAlreadySet(FunctionReference validationFunction, FunctionReference preValidationFunction);
error ValidationAlreadySet(bytes4 selector, FunctionReference validationFunction);
error ValidationNotSet(bytes4 selector, FunctionReference validationFunction);
error PreValidationHookLimitExceeded();

function _installValidation(
FunctionReference validationFunction,
Expand All @@ -36,19 +39,21 @@ abstract contract PluginManager2 {
for (uint256 i = 0; i < preValidationFunctions.length; ++i) {
FunctionReference preValidationFunction = preValidationFunctions[i];

if (
!_storage.validationData[validationFunction].preValidationHooks.add(
toSetValue(preValidationFunction)
)
) {
revert PreValidationAlreadySet(validationFunction, preValidationFunction);
}
_storage.validationData[validationFunction].preValidationHooks.push(preValidationFunction);

if (initDatas[i].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onInstall(initDatas[i]);
}
}

// Avoid collision between reserved index and actual indices
if (
_storage.validationData[validationFunction].preValidationHooks.length
> _RESERVED_VALIDATION_DATA_INDEX
) {
revert PreValidationHookLimitExceeded();
}
}

if (isDefault) {
Expand Down Expand Up @@ -85,16 +90,16 @@ abstract contract PluginManager2 {
bytes[] memory preValidationHookUninstallDatas = abi.decode(preValidationHookUninstallData, (bytes[]));

// Clear pre validation hooks
EnumerableSet.Bytes32Set storage preValidationHooks =
FunctionReference[] storage preValidationHooks =
_storage.validationData[validationFunction].preValidationHooks;
while (preValidationHooks.length() > 0) {
FunctionReference preValidationFunction = toFunctionReference(preValidationHooks.at(0));
preValidationHooks.remove(toSetValue(preValidationFunction));
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
for (uint256 i = 0; i < preValidationHooks.length; ++i) {
FunctionReference preValidationFunction = preValidationHooks[i];
if (preValidationHookUninstallDatas[0].length > 0) {
(address preValidationPlugin,) = FunctionReferenceLib.unpack(preValidationFunction);
IPlugin(preValidationPlugin).onUninstall(preValidationHookUninstallDatas[0]);
}
}
delete _storage.validationData[validationFunction].preValidationHooks;

// Because this function also calls `onUninstall`, and removes the default flag from validation, we must
// assume these selectors passed in to be exhaustive.
Expand Down
114 changes: 78 additions & 36 deletions src/account/UpgradeableModularAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {FunctionReferenceLib} from "../helpers/FunctionReferenceLib.sol";
import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol";
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol";
import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol";
import {IValidation} from "../interfaces/IValidation.sol";
Expand All @@ -19,14 +20,7 @@ import {FunctionReference, IPluginManager} from "../interfaces/IPluginManager.so
import {IStandardExecutor, Call} from "../interfaces/IStandardExecutor.sol";
import {AccountExecutor} from "./AccountExecutor.sol";
import {AccountLoupe} from "./AccountLoupe.sol";
import {
AccountStorage,
getAccountStorage,
SelectorData,
toSetValue,
toFunctionReference,
toExecutionHook
} from "./AccountStorage.sol";
import {AccountStorage, getAccountStorage, SelectorData, toSetValue, toExecutionHook} from "./AccountStorage.sol";
import {AccountStorageInitializable} from "./AccountStorageInitializable.sol";
import {PluginManagerInternals} from "./PluginManagerInternals.sol";
import {PluginManager2} from "./PluginManager2.sol";
Expand All @@ -45,6 +39,7 @@ contract UpgradeableModularAccount is
{
using EnumerableSet for EnumerableSet.Bytes32Set;
using FunctionReferenceLib for FunctionReference;
using SparseCalldataSegmentLib for bytes;

struct PostExecToRun {
bytes preExecHookReturnData;
Expand Down Expand Up @@ -77,6 +72,9 @@ contract UpgradeableModularAccount is
error UnrecognizedFunction(bytes4 selector);
error UserOpValidationFunctionMissing(bytes4 selector);
error ValidationDoesNotApply(bytes4 selector, address plugin, uint8 functionId, bool isDefault);
error SignatureSegmentOutOfOrder();
error NonCanonicalEncoding();
error ValidationSignatureSegmentMissing();

// Wraps execution of a native function with runtime validation and hooks
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installPlugin, uninstallPlugin
Expand Down Expand Up @@ -350,38 +348,50 @@ contract UpgradeableModularAccount is

_checkIfValidationApplies(selector, userOpValidationFunction, isDefaultValidation);

validationData =
_doUserOpValidation(selector, userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
validationData = _doUserOpValidation(userOpValidationFunction, userOp, userOp.signature[22:], userOpHash);
}

// To support gas estimation, we don't fail early when the failure is caused by a signature failure
function _doUserOpValidation(
bytes4 selector,
FunctionReference userOpValidationFunction,
PackedUserOperation memory userOp,
bytes calldata signature,
bytes32 userOpHash
) internal returns (uint256 validationData) {
userOp.signature = signature;
) internal returns (uint256) {
// Set up the per-hook data tracking fields
bytes calldata signatureSegment;
(signatureSegment, signature) = signature.getNextSegment();

if (userOpValidationFunction.isEmpty()) {
// If the validation function is empty, then the call cannot proceed.
revert UserOpValidationFunctionMissing(selector);
}

uint256 currentValidationData;
uint256 validationData;

// Do preUserOpValidation hooks
EnumerableSet.Bytes32Set storage preUserOpValidationHooks =
FunctionReference[] memory preUserOpValidationHooks =
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;

uint256 preUserOpValidationHooksLength = preUserOpValidationHooks.length();
for (uint256 i = 0; i < preUserOpValidationHooksLength; ++i) {
bytes32 key = preUserOpValidationHooks.at(i);
FunctionReference preUserOpValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
// Load per-hook data, if any is present
// The segment index is the first byte of the signature
if (signatureSegment.getIndex() == i) {
// Use the current segment
userOp.signature = signatureSegment.getBody();

if (userOp.signature.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(signatureSegment, signature) = signature.getNextSegment();

if (signatureSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
userOp.signature = "";
}

(address plugin, uint8 functionId) = preUserOpValidationHook.unpack();
currentValidationData = IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);
(address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack();
uint256 currentValidationData =
IValidationHook(plugin).preUserOpValidationHook(functionId, userOp, userOpHash);

if (uint160(currentValidationData) > 1) {
// If the aggregator is not 0 or 1, it is an unexpected value
Expand All @@ -392,35 +402,63 @@ contract UpgradeableModularAccount is

// Run the user op validationFunction
{
if (signatureSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

userOp.signature = signatureSegment.getBody();

(address plugin, uint8 functionId) = userOpValidationFunction.unpack();
currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);
uint256 currentValidationData = IValidation(plugin).validateUserOp(functionId, userOp, userOpHash);

if (preUserOpValidationHooksLength != 0) {
if (preUserOpValidationHooks.length != 0) {
// If we have other validation data we need to coalesce with
validationData = _coalesceValidation(validationData, currentValidationData);
} else {
validationData = currentValidationData;
}
}

return validationData;
}

function _doRuntimeValidation(
FunctionReference runtimeValidationFunction,
bytes calldata callData,
bytes calldata authorizationData
) internal {
// Set up the per-hook data tracking fields
bytes calldata authSegment;
(authSegment, authorizationData) = authorizationData.getNextSegment();

// run all preRuntimeValidation hooks
EnumerableSet.Bytes32Set storage preRuntimeValidationHooks =
FunctionReference[] memory preRuntimeValidationHooks =
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;

uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length();
for (uint256 i = 0; i < preRuntimeValidationHooksLength; ++i) {
bytes32 key = preRuntimeValidationHooks.at(i);
FunctionReference preRuntimeValidationHook = toFunctionReference(key);
for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
bytes memory currentAuthData;

if (authSegment.getIndex() == i) {
// Use the current segment
currentAuthData = authSegment.getBody();

if (currentAuthData.length == 0) {
revert NonCanonicalEncoding();
}

// Load the next per-hook data segment
(authSegment, authorizationData) = authorizationData.getNextSegment();

(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHook.unpack();
if (authSegment.getIndex() <= i) {
revert SignatureSegmentOutOfOrder();
}
} else {
currentAuthData = "";
}

(address hookPlugin, uint8 hookFunctionId) = preRuntimeValidationHooks[i].unpack();
try IValidationHook(hookPlugin).preRuntimeValidationHook(
hookFunctionId, msg.sender, msg.value, callData
hookFunctionId, msg.sender, msg.value, callData, currentAuthData
)
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
Expand All @@ -430,9 +468,13 @@ contract UpgradeableModularAccount is
}
}

if (authSegment.getIndex() != _RESERVED_VALIDATION_DATA_INDEX) {
revert ValidationSignatureSegmentMissing();
}

(address plugin, uint8 functionId) = runtimeValidationFunction.unpack();

try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authorizationData)
try IValidation(plugin).validateRuntime(functionId, msg.sender, msg.value, callData, authSegment.getBody())
// forgefmt: disable-start
// solhint-disable-next-line no-empty-blocks
{} catch (bytes memory revertReason) {
Expand Down
41 changes: 41 additions & 0 deletions src/helpers/SparseCalldataSegmentLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

library SparseCalldataSegmentLib {
/// @notice Splits out a segment of calldata, sparsely-packed
/// @param source The calldata to extract the segment from
/// @return segment The extracted segment
/// @return remainder The remaining calldata
function getNextSegment(bytes calldata source)
internal
pure
returns (bytes calldata segment, bytes calldata remainder)
{
// The first 8 bytes hold the length of the segment, excluding the index.
uint64 length = uint64(bytes8(source[:8]));

// The offset of the remainder of the calldata.
uint256 remainderOffset = 8 + length + 1;

// The segment is the next `length` + 1 bytes, to account for the index.
// By convention, the first byte of each segment is the index of the segment.
segment = source[8:remainderOffset];

// The remainder is the rest of the calldata.
remainder = source[remainderOffset:];
}

/// @notice Extracts the index from a segment
/// @param segment The segment to extract the index from
/// @return The index of the segment
function getIndex(bytes calldata segment) internal pure returns (uint8) {
return uint8(segment[0]);
}

/// @notice Extracts the body from a segment
/// @param segment The segment to extract the body from
/// @return The body of the segment
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
return segment[1:];
}
}
9 changes: 7 additions & 2 deletions src/interfaces/IValidationHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ interface IValidationHook is IPlugin {
/// @param sender The caller address.
/// @param value The call value.
/// @param data The calldata sent.
function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data)
external;
function preRuntimeValidationHook(
uint8 functionId,
address sender,
uint256 value,
bytes calldata data,
bytes calldata authorization
) external;

// TODO: support this hook type within the account & in the manifest

Expand Down
18 changes: 13 additions & 5 deletions test/account/AccountReturnData.t.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import {FunctionReference} from "../../src/helpers/FunctionReferenceLib.sol";
import {FunctionReference, FunctionReferenceLib} from "../../src/helpers/FunctionReferenceLib.sol";
import {Call} from "../../src/interfaces/IStandardExecutor.sol";
import {ISingleOwnerPlugin} from "../../src/plugins/owner/ISingleOwnerPlugin.sol";

Expand Down Expand Up @@ -59,8 +59,12 @@ contract AccountReturnDataTest is AccountTestBase {
account1.execute,
(address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ()))
),
abi.encodePacked(
singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION
_encodeSignature(
FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
),
SELECTOR_ASSOCIATED_VALIDATION,
""
)
);

Expand All @@ -85,8 +89,12 @@ contract AccountReturnDataTest is AccountTestBase {

bytes memory retData = account1.executeWithAuthorization(
abi.encodeCall(account1.executeBatch, (calls)),
abi.encodePacked(
singleOwnerPlugin, ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER, SELECTOR_ASSOCIATED_VALIDATION
_encodeSignature(
FunctionReferenceLib.pack(
address(singleOwnerPlugin), uint8(ISingleOwnerPlugin.FunctionId.VALIDATION_OWNER)
),
SELECTOR_ASSOCIATED_VALIDATION,
""
)
);

Expand Down
Loading

0 comments on commit 60916fa

Please sign in to comment.