Skip to content

Commit

Permalink
feat: Test and fix logic for cancelling subscriptions and withdrawing…
Browse files Browse the repository at this point in the history
… funds from subscriptions
  • Loading branch information
mgnfy-view committed Dec 17, 2024
1 parent 689b905 commit 321f59f
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 12 deletions.
19 changes: 15 additions & 4 deletions src/SafeSubscriptions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
Safe private s_safe;
uint256 private s_nonce;
mapping(bytes32 subscriptionDataHash => Subscription subscription) private s_subscriptions;
mapping(bytes32 subscriptionDataHash => bool isCancelled) private s_isCancelled;

constructor(address _safe) EIP712("Safe Subscriptions", "1") {
s_safe = Safe(payable(_safe));
Expand All @@ -32,6 +33,7 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
bytes memory _signatures
)
external
returns (bytes32)
{
if (
_subscription.serviceProvider == address(0) || _subscription.token == address(0)
Expand Down Expand Up @@ -59,6 +61,8 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
s_nonce++;

emit SubscriptionCreated(_subscription);

return subscriptionDataHash;
}

function cancelSubscription(
Expand All @@ -85,7 +89,7 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
_checkDeadline(_deadline);
_checkNonce(_nonce);

delete s_subscriptions[_subscriptionDataHash];
s_isCancelled[_subscriptionDataHash] = true;
s_nonce++;

emit SubscriptionCancelled(_subscriptionDataHash);
Expand All @@ -94,20 +98,23 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
function withdrawFromSubscription(bytes32 _subscriptionDataHash) external {
Subscription memory subscription = s_subscriptions[_subscriptionDataHash];

if (subscription.serviceProvider == address(0)) revert SubscriptionDoesNotExist(_subscriptionDataHash);
if (s_isCancelled[_subscriptionDataHash]) revert SubscriptionRevoked();
if (subscription.startingTimestamp > block.timestamp) {
revert SubscriptionHasNotStartedYet(_subscriptionDataHash);
}

uint256 roundsToClaim;
if (subscription.isRecurring) {
roundsToClaim = (block.timestamp - subscription.startingTimestamp / subscription.duration)
roundsToClaim = ((block.timestamp - subscription.startingTimestamp) / subscription.duration)
- subscription.roundsClaimedSoFar;
} else {
roundsToClaim = (block.timestamp - subscription.startingTimestamp / subscription.duration);
roundsToClaim = (block.timestamp - subscription.startingTimestamp) / subscription.duration;
if (roundsToClaim > subscription.rounds) roundsToClaim = subscription.rounds;
roundsToClaim -= subscription.roundsClaimedSoFar;
}
uint256 amountToWithdraw = roundsToClaim * subscription.amount;
if (amountToWithdraw == 0) revert ZeroAmountToWithdraw();
s_subscriptions[_subscriptionDataHash].roundsClaimedSoFar += roundsToClaim;

bool success;
Expand Down Expand Up @@ -171,10 +178,14 @@ contract SafeSubscriptions is EIP712, ISafeSubscriptions {
return s_nonce + 1;
}

function getSbscriptionData(bytes32 _subscriptionDataHash) external view returns (Subscription memory) {
function getSubscriptionData(bytes32 _subscriptionDataHash) external view returns (Subscription memory) {
return s_subscriptions[_subscriptionDataHash];
}

function isSubscriptionCancelled(bytes32 _subscriptionDataHash) external view returns (bool) {
return s_isCancelled[_subscriptionDataHash];
}

function getEncodedSubscriptionDataAndHash(
Subscription memory _subscription,
uint256 _deadline,
Expand Down
4 changes: 3 additions & 1 deletion src/interfaces/ISafeSubscriptions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ interface ISafeSubscriptions {
error InvalidSubscription();
error SubscriptionAlreadyExists(bytes32 subscriptionDataHash);
error SubscriptionDoesNotExist(bytes32 subscriptionDataHash);
error SubscriptionRevoked();
error SubscriptionHasNotStartedYet(bytes32 subscriptionDataHash);
error ZeroAmountToWithdraw();
error TransactionFailed();
error DeadlinePassed(uint256 deadline, uint256 currentTimestamp);
error InvalidNonce(uint256 givenNonce, uint256 _expectedNonce);
error InvalidNonce(uint256 givenNonce, uint256 expectedNonce);
}
72 changes: 72 additions & 0 deletions test/unit/CancelSubscription.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.24;

import { ISafeSubscriptions } from "../../src/interfaces/ISafeSubscriptions.sol";

import { GlobalHelper } from "../utils/GlobalHelper.sol";

contract CancelSubscriptionTest is GlobalHelper {
function test_cancellingSubscriptionFailsIfSubscriptionDoesNotExist() public {
(ISafeSubscriptions.Subscription memory subscriptionData,,) = _getTestCreateSubscriptionData();
bytes32 subscriptionDataHash = safeSubscriptions.getSubscriptionDataHash(subscriptionData);

vm.expectRevert(
abi.encodeWithSelector(ISafeSubscriptions.SubscriptionDoesNotExist.selector, subscriptionDataHash)
);
safeSubscriptions.cancelSubscription(subscriptionDataHash, 0, 0, "");
}

function test_cancellingSubscriptionFailsIfDeadlineHasPassed() public {
bytes32 subscriptionDataHash = _createTestSubscription();
ISafeSubscriptions.Subscription memory subscriptionData =
safeSubscriptions.getSubscriptionData(subscriptionDataHash);

_warpBy(delay);

uint256 deadline = block.timestamp - 1;
uint256 nonce = safeSubscriptions.getNextNonce();
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscriptionData, deadline, nonce);

vm.expectRevert(abi.encodeWithSelector(ISafeSubscriptions.DeadlinePassed.selector, deadline, block.timestamp));
safeSubscriptions.cancelSubscription(subscriptionDataHash, deadline, nonce, signatures);
}

function test_cancellingSubscriptionFailsIfInvalidNonceIsPassed() public {
bytes32 subscriptionDataHash = _createTestSubscription();
ISafeSubscriptions.Subscription memory subscriptionData =
safeSubscriptions.getSubscriptionData(subscriptionDataHash);
uint256 deadline = block.timestamp + delay;
uint256 nonce = safeSubscriptions.getNextNonce() + 1;
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscriptionData, deadline, nonce);

vm.expectRevert(abi.encodeWithSelector(ISafeSubscriptions.InvalidNonce.selector, nonce, nonce - 1));
safeSubscriptions.cancelSubscription(subscriptionDataHash, deadline, nonce, signatures);
}

function test_cancellingSubscriptionSucceeds() public {
bytes32 subscriptionDataHash = _createTestSubscription();
ISafeSubscriptions.Subscription memory subscriptionData =
safeSubscriptions.getSubscriptionData(subscriptionDataHash);
uint256 deadline = block.timestamp + delay;
uint256 nonce = safeSubscriptions.getNextNonce();
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscriptionData, deadline, nonce);

safeSubscriptions.cancelSubscription(subscriptionDataHash, deadline, nonce, signatures);

assertEq(safeSubscriptions.getNextNonce(), nonce + 1);
assertTrue(safeSubscriptions.isSubscriptionCancelled(subscriptionDataHash));
}

function test_cancellingSubscriptionEmitsEvent() public {
bytes32 subscriptionDataHash = _createTestSubscription();
ISafeSubscriptions.Subscription memory subscriptionData =
safeSubscriptions.getSubscriptionData(subscriptionDataHash);
uint256 deadline = block.timestamp + delay;
uint256 nonce = safeSubscriptions.getNextNonce();
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscriptionData, deadline, nonce);

vm.expectEmit(true, true, true, true);
emit ISafeSubscriptions.SubscriptionCancelled(subscriptionDataHash);
safeSubscriptions.cancelSubscription(subscriptionDataHash, deadline, nonce, signatures);
}
}
12 changes: 6 additions & 6 deletions test/unit/CreateSubscription.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ contract CreateSubscriptionTest is GlobalHelper {
(ISafeSubscriptions.Subscription memory subscription, uint256 deadline, uint256 nonce) =
_getTestCreateSubscriptionData();
subscription.startingTimestamp = block.timestamp + 2 * delay;
bytes memory signatures = _getSignaturesForSubscriptionCreation(subscription, deadline, nonce);
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscription, deadline, nonce);

_warpBy(delay + 1);

Expand All @@ -94,7 +94,7 @@ contract CreateSubscriptionTest is GlobalHelper {
(ISafeSubscriptions.Subscription memory subscription, uint256 deadline, uint256 nonce) =
_getTestCreateSubscriptionData();
nonce = 10;
bytes memory signatures = _getSignaturesForSubscriptionCreation(subscription, deadline, nonce);
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscription, deadline, nonce);

vm.expectRevert(
abi.encodeWithSelector(ISafeSubscriptions.InvalidNonce.selector, nonce, safeSubscriptions.getNextNonce())
Expand All @@ -105,12 +105,12 @@ contract CreateSubscriptionTest is GlobalHelper {
function test_creatingANewSubscriptionSucceeds() public {
(ISafeSubscriptions.Subscription memory subscription, uint256 deadline, uint256 nonce) =
_getTestCreateSubscriptionData();
bytes memory signatures = _getSignaturesForSubscriptionCreation(subscription, deadline, nonce);
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscription, deadline, nonce);

safeSubscriptions.createSubscription(subscription, deadline, nonce, signatures);
bytes32 subscriptionDataHash = safeSubscriptions.createSubscription(subscription, deadline, nonce, signatures);

ISafeSubscriptions.Subscription memory subscriptionData =
safeSubscriptions.getSbscriptionData(safeSubscriptions.getSubscriptionDataHash(subscription));
safeSubscriptions.getSubscriptionData(subscriptionDataHash);
assertEq(subscriptionData.serviceProvider, subscription.serviceProvider);
assertEq(subscriptionData.token, subscription.token);
assertEq(subscriptionData.amount, subscription.amount);
Expand All @@ -127,7 +127,7 @@ contract CreateSubscriptionTest is GlobalHelper {
function test_creatingANewSubscriptionEmitsEvent() public {
(ISafeSubscriptions.Subscription memory subscription, uint256 deadline, uint256 nonce) =
_getTestCreateSubscriptionData();
bytes memory signatures = _getSignaturesForSubscriptionCreation(subscription, deadline, nonce);
bytes memory signatures = _getSignaturesForSubscriptionOperation(subscription, deadline, nonce);

vm.expectEmit(true, true, true, true);
emit ISafeSubscriptions.SubscriptionCreated(subscription);
Expand Down
Loading

0 comments on commit 321f59f

Please sign in to comment.