diff --git a/starknet/contracts/token_vesting/.gitignore b/starknet/contracts/token_vesting/.gitignore new file mode 100755 index 0000000..98ab566 --- /dev/null +++ b/starknet/contracts/token_vesting/.gitignore @@ -0,0 +1,3 @@ +target +.snfoundry_cache/ +deployment diff --git a/starknet/contracts/token_vesting/.tool-versions b/starknet/contracts/token_vesting/.tool-versions new file mode 100644 index 0000000..d03b7e6 --- /dev/null +++ b/starknet/contracts/token_vesting/.tool-versions @@ -0,0 +1,3 @@ +starknet 2.9.2 +starknet-foundry 0.34.0 +scarb 2.9.2 diff --git a/starknet/contracts/token_vesting/Scarb.lock b/starknet/contracts/token_vesting/Scarb.lock new file mode 100644 index 0000000..96362ba --- /dev/null +++ b/starknet/contracts/token_vesting/Scarb.lock @@ -0,0 +1,63 @@ +# Code generated by scarb DO NOT EDIT. +version = 1 + +[[package]] +name = "openzeppelin_access" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" +dependencies = [ + "openzeppelin_introspection", +] + +[[package]] +name = "openzeppelin_account" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" +dependencies = [ + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_introspection" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" + +[[package]] +name = "openzeppelin_token" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" +dependencies = [ + "openzeppelin_access", + "openzeppelin_account", + "openzeppelin_introspection", + "openzeppelin_utils", +] + +[[package]] +name = "openzeppelin_utils" +version = "0.20.0" +source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.20.0#7756fd1de2b4ebd239fa6e372d75535cea02e5e5" + +[[package]] +name = "snforge_scarb_plugin" +version = "0.34.0" +source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.34.0#d6976d4635cbe69bd199fd502788c469d408ed2d" + +[[package]] +name = "snforge_std" +version = "0.34.0" +source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.34.0#d6976d4635cbe69bd199fd502788c469d408ed2d" +dependencies = [ + "snforge_scarb_plugin", +] + +[[package]] +name = "token_vesting" +version = "0.1.0" +dependencies = [ + "openzeppelin_access", + "openzeppelin_token", + "openzeppelin_utils", + "snforge_std", +] diff --git a/starknet/contracts/token_vesting/Scarb.toml b/starknet/contracts/token_vesting/Scarb.toml new file mode 100644 index 0000000..3756f68 --- /dev/null +++ b/starknet/contracts/token_vesting/Scarb.toml @@ -0,0 +1,53 @@ +[package] +name = "token_vesting" +version = "0.1.0" +edition = "2024_07" + +# See more keys and their definitions at https://docs.swmansion.com/scarb/docs/reference/manifest.html + +[dependencies] +starknet = "2.9.2" +openzeppelin_access = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v0.20.0" } +openzeppelin_token = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v0.20.0" } +openzeppelin_utils = { git = "https://github.com/OpenZeppelin/cairo-contracts.git", tag = "v0.20.0" } + + +[dev-dependencies] +snforge_std = { git = "https://github.com/foundry-rs/starknet-foundry", tag = "v0.34.0" } +assert_macros = "2.9.2" + +[[target.starknet-contract]] +sierra = true + +[scripts] +test = "snforge test" + +# Visit https://foundry-rs.github.io/starknet-foundry/appendix/scarb-toml.html for more information + +# [tool.snforge] # Define `snforge` tool section +# exit_first = true # Stop tests execution immediately upon the first failure +# fuzzer_runs = 1234 # Number of runs of the random fuzzer +# fuzzer_seed = 1111 # Seed for the random fuzzer + +# [[tool.snforge.fork]] # Used for fork testing +# name = "SOME_NAME" # Fork name +# url = "http://your.rpc.url" # Url of the RPC provider +# block_id.tag = "latest" # Block to fork from (block tag) + +# [[tool.snforge.fork]] +# name = "SOME_SECOND_NAME" +# url = "http://your.second.rpc.url" +# block_id.number = "123" # Block to fork from (block number) + +# [[tool.snforge.fork]] +# name = "SOME_THIRD_NAME" +# url = "http://your.third.rpc.url" +# block_id.hash = "0x123" # Block to fork from (block hash) + +# [profile.dev.cairo] # Configure Cairo compiler +# unstable-add-statements-code-locations-debug-info = true # Should be used if you want to use coverage +# unstable-add-statements-functions-debug-info = true # Should be used if you want to use coverage/profiler +# inlining-strategy = "avoid" # Should be used if you want to use coverage + +# [features] # Used for conditional compilation +# enable_for_tests = [] # Feature name and list of other features that should be enabled with it diff --git a/starknet/contracts/token_vesting/snfoundry.toml b/starknet/contracts/token_vesting/snfoundry.toml new file mode 100644 index 0000000..306a097 --- /dev/null +++ b/starknet/contracts/token_vesting/snfoundry.toml @@ -0,0 +1,11 @@ +# Visit https://foundry-rs.github.io/starknet-foundry/appendix/snfoundry-toml.html +# and https://foundry-rs.github.io/starknet-foundry/projects/configuration.html for more information + +# [sncast.default] # Define a profile name +# url = "https://free-rpc.nethermind.io/sepolia-juno/v0_7" # Url of the RPC provider +# accounts-file = "../account-file" # Path to the file with the account data +# account = "mainuser" # Account from `accounts_file` or default account file that will be used for the transactions +# keystore = "~/keystore" # Path to the keystore file +# wait-params = { timeout = 300, retry-interval = 10 } # Wait for submitted transaction parameters +# block-explorer = "StarkScan" # Block explorer service used to display links to transaction details +# show-explorer-links = true # Print links pointing to pages with transaction details in the chosen block explorer diff --git a/starknet/contracts/token_vesting/src/lib.cairo b/starknet/contracts/token_vesting/src/lib.cairo new file mode 100644 index 0000000..af05e6f --- /dev/null +++ b/starknet/contracts/token_vesting/src/lib.cairo @@ -0,0 +1,2 @@ +pub mod vesting; +pub mod mocks; diff --git a/starknet/contracts/token_vesting/src/mocks.cairo b/starknet/contracts/token_vesting/src/mocks.cairo new file mode 100644 index 0000000..b51ad9e --- /dev/null +++ b/starknet/contracts/token_vesting/src/mocks.cairo @@ -0,0 +1 @@ +pub mod free_erc20; diff --git a/starknet/contracts/token_vesting/src/mocks/free_erc20.cairo b/starknet/contracts/token_vesting/src/mocks/free_erc20.cairo new file mode 100644 index 0000000..3589d96 --- /dev/null +++ b/starknet/contracts/token_vesting/src/mocks/free_erc20.cairo @@ -0,0 +1,55 @@ +use starknet::ContractAddress; + +//mock erc20 +#[starknet::interface] +pub trait IFreeMint { + fn mint(ref self: T, recipient: ContractAddress, amount: u256); +} + +#[starknet::contract] +mod FreeMintERC20 { + use openzeppelin_token::erc20::ERC20Component; + use openzeppelin_token::erc20::ERC20HooksEmptyImpl; + use starknet::ContractAddress; + use super::IFreeMint; + + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20MetadataImpl = ERC20Component::ERC20MetadataImpl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + erc20: ERC20Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC20Event: ERC20Component::Event, + } + + #[constructor] + fn constructor( + ref self: ContractState, + initial_supply: u256, + name: core::byte_array::ByteArray, + symbol: core::byte_array::ByteArray, + ) { + self.erc20.initializer(name, symbol); + } + + #[abi(embed_v0)] + impl ImplFreeMint of IFreeMint { + fn mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { + self.erc20.mint(recipient, amount); + } + } +} diff --git a/starknet/contracts/token_vesting/src/vesting.cairo b/starknet/contracts/token_vesting/src/vesting.cairo new file mode 100644 index 0000000..4d51dce --- /dev/null +++ b/starknet/contracts/token_vesting/src/vesting.cairo @@ -0,0 +1,299 @@ +use starknet::ContractAddress; + +#[derive(Drop, Serde, Copy, starknet::Store)] +pub struct Schedule { + pub recipient: ContractAddress, + pub token: ContractAddress, + pub start_time: u64, + pub cliff_time: u64, + pub end_time: u64, + pub total_claimed: u256, + pub total_amount: u256, +} + +#[starknet::interface] +pub trait IOwnable { + fn owner(self: @ContractState) -> ContractAddress; +} + +#[starknet::interface] +pub trait IVesting { + fn add_schedule( + ref self: ContractState, + token: ContractAddress, + recipient: ContractAddress, + start_time: u64, + cliff_time: u64, + end_time: u64, + total_amount: u256, + ); + + fn remove_schedule( + ref self: ContractState, + token: ContractAddress, + address_to_end: ContractAddress, + refund_address: ContractAddress, + ); + + fn claim(ref self: ContractState, token: ContractAddress); + + fn get_vested_amount(self: @ContractState, user: ContractAddress) -> u256; + + fn get_claimable_amount(self: @ContractState, user: ContractAddress) -> u256; + + fn get_user_vesting_schedule(self: @ContractState, user: ContractAddress) -> Schedule; +} + +#[starknet::contract] +pub mod Vesting { + use core::num::traits::Zero; + use starknet::event::EventEmitter; + use super::{IVesting, Schedule}; + use openzeppelin_access::ownable::OwnableComponent; + use openzeppelin_token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use core::starknet::{ + ContractAddress, get_block_timestamp, get_caller_address, get_contract_address, + contract_address_const, + }; + use core::starknet::storage::{ + StoragePointerReadAccess, StoragePointerWriteAccess, Map, StoragePathEntry, + }; + + // Ownable Component + component!(path: OwnableComponent, storage: ownable, event: OwnableEvent); + + // Ownable Mixin + #[abi(embed_v0)] + impl OwnableMixinImpl = OwnableComponent::OwnableMixinImpl; + impl OwnableInternalImpl = OwnableComponent::InternalImpl; + + #[storage] + struct Storage { + #[substorage(v0)] + ownable: OwnableComponent::Storage, + schedules: Map, + } + + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event { + #[flat] + OwnableEvent: OwnableComponent::Event, + NewScheduleAdded: NewScheduleAdded, + SuccessfulClaim: SuccessfulClaim, + VestingEndedByOwner: VestingEndedByOwner, + } + + #[derive(Drop, starknet::Event)] + pub struct NewScheduleAdded { + pub recipient: ContractAddress, + pub token: ContractAddress, + pub start_time: u64, + pub cliff_time: u64, + pub end_time: u64, + pub amount: u256, + } + + #[derive(Drop, starknet::Event)] + pub struct SuccessfulClaim { + pub recipient: ContractAddress, + pub token: ContractAddress, + pub amount: u256, + } + + #[derive(Drop, starknet::Event)] + pub struct VestingEndedByOwner { + pub address_ended: ContractAddress, + pub token: ContractAddress, + pub amount_withdrawn: u256, + pub amount_refunded: u256, + } + + pub mod Errors { + pub const ZERO_ADDRESS: felt252 = 'Zero address detected'; + pub const ZERO_AMOUNT: felt252 = 'Amount cannot be zero'; + pub const INVALID_CLIFF_TIME: felt252 = 'Cliff time is invalid'; + pub const INVALID_END_TIME: felt252 = 'End time is invalid'; + pub const INVALID_PERCENTAGE: felt252 = 'Percentage greater than 100'; + pub const TOKEN_TRANSFER_FAILED: felt252 = 'Token transfer failed'; + pub const ALREADY_HAS_LOCK: felt252 = 'User already has lock'; + } + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + assert(!owner.is_zero(), Errors::ZERO_ADDRESS); + self.ownable.initializer(owner); + } + + #[abi(embed_v0)] + impl VestingImpl of IVesting { + fn add_schedule( + ref self: ContractState, + token: ContractAddress, + recipient: ContractAddress, + start_time: u64, + cliff_time: u64, + end_time: u64, + total_amount: u256, + ) { + self.ownable.assert_only_owner(); + assert(total_amount > 0, Errors::ZERO_AMOUNT); + assert(cliff_time >= start_time, Errors::INVALID_CLIFF_TIME); + assert(end_time >= cliff_time, Errors::INVALID_END_TIME); + + let schedule = self.get_user_vesting_schedule(recipient); + + assert(schedule.total_amount == 0, Errors::ALREADY_HAS_LOCK); + + let this_contract = get_contract_address(); + + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + let caller = get_caller_address(); + + assert( + token_dispatcher.transfer_from(caller, this_contract, total_amount), + Errors::TOKEN_TRANSFER_FAILED, + ); + + let new_schedule = Schedule { + recipient: recipient, + token: token, + start_time: start_time, + cliff_time: cliff_time, + end_time: end_time, + total_claimed: 0, + total_amount: total_amount, + }; + + self.schedules.entry(recipient).write(new_schedule); + self + .emit( + NewScheduleAdded { + recipient: recipient, + token: token, + start_time: start_time, + cliff_time: cliff_time, + end_time: end_time, + amount: total_amount, + }, + ); + } + + fn remove_schedule( + ref self: ContractState, + token: ContractAddress, + address_to_end: ContractAddress, + refund_address: ContractAddress, + ) { + self.ownable.assert_only_owner(); + let mut amount_refundable = 0; + let mut amount_withdrawable = 0; + + let schedule = self.get_user_vesting_schedule(address_to_end); + + if get_block_timestamp() < schedule.cliff_time { + amount_refundable = schedule.total_amount; + } else { + let amount_vested = self.get_vested_amount(address_to_end); + amount_withdrawable = amount_vested - schedule.total_claimed; + amount_refundable = schedule.total_amount - amount_vested; + } + + if amount_refundable > 0 { + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + + assert( + token_dispatcher.transfer(refund_address, amount_refundable), + Errors::TOKEN_TRANSFER_FAILED, + ); + } + + if amount_withdrawable > 0 { + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + + assert( + token_dispatcher.transfer(address_to_end, amount_withdrawable), + Errors::TOKEN_TRANSFER_FAILED, + ); + } + + let empty_schedule = Schedule { + recipient: self.zero_address(), + token: self.zero_address(), + start_time: 0, + cliff_time: 0, + end_time: 0, + total_claimed: 0, + total_amount: 0, + }; + + self.schedules.entry(address_to_end).write(empty_schedule); + + self + .emit( + VestingEndedByOwner { + address_ended: address_to_end, + token: token, + amount_withdrawn: amount_withdrawable, + amount_refunded: amount_refundable, + }, + ) + } + + fn claim(ref self: ContractState, token: ContractAddress) { + let caller = get_caller_address(); + let schedule = self.get_user_vesting_schedule(caller); + let claimable = self.get_claimable_amount(caller); + + if claimable > 0 { + let mut updated_schedule = schedule; + updated_schedule.total_claimed = updated_schedule.total_claimed + claimable; + self.schedules.entry(caller).write(updated_schedule); + + let token_dispatcher = IERC20Dispatcher { contract_address: token }; + + assert(token_dispatcher.transfer(caller, claimable), Errors::TOKEN_TRANSFER_FAILED); + + self.emit(SuccessfulClaim { recipient: caller, token: token, amount: claimable }); + } + } + + fn get_vested_amount(self: @ContractState, user: ContractAddress) -> u256 { + let schedule = self.get_user_vesting_schedule(user); + + let now = get_block_timestamp(); + + if now < schedule.cliff_time { + 0 + } else if now >= schedule.end_time { + schedule.total_amount + } else { + let elapsed_time = now - schedule.start_time; + let total_duration = schedule.end_time - schedule.start_time; + (schedule.total_amount * elapsed_time.into()) / total_duration.into() + } + } + + fn get_claimable_amount(self: @ContractState, user: ContractAddress) -> u256 { + let schedule = self.get_user_vesting_schedule(user); + let vested_amount = self.get_vested_amount(user); + + if vested_amount > schedule.total_claimed { + vested_amount - schedule.total_claimed + } else { + 0 + } + } + + fn get_user_vesting_schedule(self: @ContractState, user: ContractAddress) -> Schedule { + self.schedules.entry(user).read() + } + } + + #[generate_trait] + impl InternalFunctions of InternalFunctionsTrait { + fn zero_address(self: @ContractState) -> ContractAddress { + contract_address_const::<0>() + } + } +} diff --git a/starknet/contracts/token_vesting/tests/utils.cairo b/starknet/contracts/token_vesting/tests/utils.cairo new file mode 100644 index 0000000..919a27f --- /dev/null +++ b/starknet/contracts/token_vesting/tests/utils.cairo @@ -0,0 +1,82 @@ +use snforge_std::DeclareResultTrait; +use starknet::{ContractAddress, get_block_timestamp, contract_address_const}; + +use openzeppelin_utils::serde::SerializedAppend; +use snforge_std::{declare, ContractClassTrait}; + +use token_vesting::vesting::{IVestingDispatcher}; +use openzeppelin_token::erc20::interface::{IERC20Dispatcher}; +use token_vesting::mocks::free_erc20::{IFreeMintDispatcher, IFreeMintDispatcherTrait}; + +pub const ONE_E18: u256 = 1000000000000000000_u256; + +pub fn OWNER() -> ContractAddress { + contract_address_const::<'OWNER'>() +} + +pub fn RECIPIENT() -> ContractAddress { + contract_address_const::<'RECIPIENT'>() +} + +pub fn OTHER() -> ContractAddress { + contract_address_const::<'OTHER'>() +} + +pub fn OTHER_ADMIN() -> ContractAddress { + contract_address_const::<'OTHER_ADMIN'>() +} + +pub fn ZERO_ADDRESS() -> ContractAddress { + contract_address_const::<0>() +} + +pub fn declare_and_deploy(contract_name: ByteArray, calldata: Array) -> ContractAddress { + let contract = declare(contract_name).unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + contract_address +} + +pub fn deploy_erc20() -> ContractAddress { + let mut calldata = array![]; + let initial_supply: u256 = 1000_000_000_u256; + let name: ByteArray = "DummyERC20"; + let symbol: ByteArray = "DUMMY"; + + calldata.append_serde(initial_supply); + calldata.append_serde(name); + calldata.append_serde(symbol); + let erc20_address = declare_and_deploy("FreeMintERC20", calldata); + + erc20_address +} + +pub fn deploy_vesting_contract() -> IVestingDispatcher { + let mut calldata = array![]; + calldata.append_serde(OWNER()); + let vesting_contract = declare_and_deploy("Vesting", calldata); + IVestingDispatcher { contract_address: vesting_contract } +} + +pub fn setup() -> (IVestingDispatcher, IERC20Dispatcher) { + let erc20_address = deploy_erc20(); + let initial_amount: u256 = 1_000_000_u256 * ONE_E18; + IFreeMintDispatcher { contract_address: erc20_address }.mint(OWNER(), initial_amount); + let erc20_contract = IERC20Dispatcher { contract_address: erc20_address }; + let vesting_contract = deploy_vesting_contract(); + (vesting_contract, erc20_contract) +} + + +pub fn generate_schedule(duration_in_secs: u64, cliff: bool) -> (u64, u64, u64) { + let start_time = get_block_timestamp() + 1000_u64; + + let cliff_time = if cliff { + start_time + (duration_in_secs / 5_u64) // 20% = 1/5 + } else { + start_time + }; + + let end_time = start_time + duration_in_secs; + + (start_time, cliff_time, end_time) +} diff --git a/starknet/contracts/token_vesting/tests/vesting.cairo b/starknet/contracts/token_vesting/tests/vesting.cairo new file mode 100644 index 0000000..8419341 --- /dev/null +++ b/starknet/contracts/token_vesting/tests/vesting.cairo @@ -0,0 +1,6 @@ +#[cfg(test)] +mod test_admin; + +#[cfg(test)] +mod test_vesting; + diff --git a/starknet/contracts/token_vesting/tests/vesting/test_admin.cairo b/starknet/contracts/token_vesting/tests/vesting/test_admin.cairo new file mode 100644 index 0000000..f1244d6 --- /dev/null +++ b/starknet/contracts/token_vesting/tests/vesting/test_admin.cairo @@ -0,0 +1,168 @@ +use snforge_std::EventSpyAssertionsTrait; + +use snforge_std::{spy_events, cheat_caller_address, CheatSpan}; + +use token_vesting::vesting::{Vesting, IVestingDispatcherTrait}; +use openzeppelin_token::erc20::interface::{IERC20DispatcherTrait}; + +use crate::utils::*; + +#[test] +fn test_add_schedule_without_cliff() { + let (vesting_contract, erc20_token) = setup(); + + let mut spy = spy_events(); + + let (start_time, cliff_time, end_time) = generate_schedule(2000, false); + let amount = 10_000_u256 * ONE_E18; + + cheat_caller_address(erc20_token.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + erc20_token.approve(vesting_contract.contract_address, amount); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); + + let expected_event = Vesting::Event::NewScheduleAdded( + Vesting::NewScheduleAdded { + recipient: RECIPIENT(), + token: erc20_token.contract_address, + start_time: start_time, + cliff_time: cliff_time, + end_time: end_time, + amount: amount, + }, + ); + + spy.assert_emitted(@array![(vesting_contract.contract_address, expected_event)]); + + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + assert!(RECIPIENT() == user_schedule.recipient, "wrong recipient in record"); + assert!( + erc20_token.balance_of(vesting_contract.contract_address) == amount, + "vesting_contract not incremented", + ) +} + +#[test] +fn test_add_schedule_with_cliff() { + let (vesting_contract, erc20_token) = setup(); + + let mut spy = spy_events(); + + let (start_time, cliff_time, end_time) = generate_schedule(2000, true); + let amount = 10000_u256 * ONE_E18; + + cheat_caller_address(erc20_token.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + erc20_token.approve(vesting_contract.contract_address, amount); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); + + let expected_event = Vesting::Event::NewScheduleAdded( + Vesting::NewScheduleAdded { + recipient: RECIPIENT(), + token: erc20_token.contract_address, + start_time: start_time, + cliff_time: cliff_time, + end_time: end_time, + amount: amount, + }, + ); + + spy.assert_emitted(@array![(vesting_contract.contract_address, expected_event)]); + + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + assert!(RECIPIENT() == user_schedule.recipient, "wrong recipient in record"); + assert!(amount == user_schedule.total_amount, "wrong recipient in record"); + assert!( + erc20_token.balance_of(vesting_contract.contract_address) == amount, + "vesting_contract not incremented", + ) +} + +#[test] +#[should_panic] +fn test_admin_cannot_add_schedule_for_same_user() { + let (vesting_contract, erc20_token) = setup(); + + let (start_time, cliff_time, end_time) = generate_schedule(2000, true); + let amount = 10000_u256 * ONE_E18; + + cheat_caller_address(erc20_token.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + erc20_token.approve(vesting_contract.contract_address, amount); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); +} + +#[test] +#[should_panic] +fn test_not_admin_cannot_add_schedule() { + let (vesting_contract, erc20_token) = setup(); + + let (start_time, cliff_time, end_time) = generate_schedule(2000, true); + let amount = 10000_u256 * ONE_E18; + + cheat_caller_address( + vesting_contract.contract_address, OTHER_ADMIN(), CheatSpan::TargetCalls(1), + ); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); +} +#[test] +#[should_panic] +fn test_not_admin_cannot_add_schedule_with_invalid_cliff_time() { + let (vesting_contract, erc20_token) = setup(); + + let (start_time, _, end_time) = generate_schedule(2000, true); + let amount = 10000_u256 * ONE_E18; + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, + RECIPIENT(), + start_time, + start_time - 1000_u64, + end_time, + amount, + ); +} +#[test] +#[should_panic] +fn test_not_admin_cannot_add_schedule_with_invalid_end_time() { + let (vesting_contract, erc20_token) = setup(); + + let (start_time, cliff_time, _) = generate_schedule(2000, true); + let amount = 10000_u256 * ONE_E18; + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, + RECIPIENT(), + start_time, + cliff_time, + cliff_time - 1000_u64, + amount, + ); +} diff --git a/starknet/contracts/token_vesting/tests/vesting/test_vesting.cairo b/starknet/contracts/token_vesting/tests/vesting/test_vesting.cairo new file mode 100644 index 0000000..0680784 --- /dev/null +++ b/starknet/contracts/token_vesting/tests/vesting/test_vesting.cairo @@ -0,0 +1,265 @@ +use snforge_std::EventSpyAssertionsTrait; +use starknet::{get_block_timestamp}; + +use snforge_std::{ + spy_events, cheat_caller_address, CheatSpan, start_cheat_block_timestamp_global, + stop_cheat_block_timestamp_global, +}; + +use token_vesting::vesting::{Vesting, IVestingDispatcher, IVestingDispatcherTrait}; +use openzeppelin_token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + +use crate::utils::*; + +const amount: u256 = 10_000_u256 * ONE_E18; + +fn add_schedule() -> (IVestingDispatcher, IERC20Dispatcher) { + let (vesting_contract, erc20_token) = setup(); + + let mut spy = spy_events(); + let duration = 20_000; + let (start_time, cliff_time, end_time) = generate_schedule(duration, true); + let amount = 10_000_u256 * ONE_E18; + + cheat_caller_address(erc20_token.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + erc20_token.approve(vesting_contract.contract_address, amount); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract + .add_schedule( + erc20_token.contract_address, RECIPIENT(), start_time, cliff_time, end_time, amount, + ); + + let expected_event = Vesting::Event::NewScheduleAdded( + Vesting::NewScheduleAdded { + recipient: RECIPIENT(), + token: erc20_token.contract_address, + start_time: start_time, + cliff_time: cliff_time, + end_time: end_time, + amount: amount, + }, + ); + + spy.assert_emitted(@array![(vesting_contract.contract_address, expected_event)]); + + (vesting_contract, erc20_token) +} + + +#[test] +fn test_user_cannot_claim_if_cliff_not_reached() { + let (vesting_contract, erc20_token) = add_schedule(); + + let time_stamp = get_block_timestamp(); + let prev_balance = erc20_token.balance_of(RECIPIENT()); + + // duration is 20_000secs i.e cliff is set to 20% + start_cheat_block_timestamp_global(time_stamp + 1000); + + let claimable = vesting_contract.get_claimable_amount(RECIPIENT()); + + cheat_caller_address(vesting_contract.contract_address, RECIPIENT(), CheatSpan::TargetCalls(1)); + vesting_contract.claim(erc20_token.contract_address); + + stop_cheat_block_timestamp_global(); + + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + let curr_balance = erc20_token.balance_of(RECIPIENT()); + + assert!(curr_balance == prev_balance, "Current balance does not match previous balance"); + + assert!(claimable == 0, "Claimable amount is not zero as expected"); + + assert!(user_schedule.total_claimed == 0, "Total claimed amount is not zero as expected"); +} + +#[test] +fn test_user_can_claim_part_if_cliff_exceeded() { + let mut spy = spy_events(); + + let (vesting_contract, erc20_token) = add_schedule(); + let time_stamp = get_block_timestamp(); + let prev_balance = erc20_token.balance_of(RECIPIENT()); + + // duration is 20_000secs i.e cliff is set to 20% + start_cheat_block_timestamp_global(time_stamp + 5000); + + let claimable = vesting_contract.get_claimable_amount(RECIPIENT()); + + cheat_caller_address(vesting_contract.contract_address, RECIPIENT(), CheatSpan::TargetCalls(1)); + vesting_contract.claim(erc20_token.contract_address); + stop_cheat_block_timestamp_global(); + + let expected_event = Vesting::Event::SuccessfulClaim( + Vesting::SuccessfulClaim { + recipient: RECIPIENT(), token: erc20_token.contract_address, amount: claimable, + }, + ); + + spy.assert_emitted(@array![(vesting_contract.contract_address, expected_event)]); + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + let curr_balance = erc20_token.balance_of(RECIPIENT()); + + assert!( + curr_balance >= prev_balance + claimable, "Recipient balance did not increase as expected", + ); + + assert!( + user_schedule.total_claimed >= claimable, + "Claimed amount is less than the claimable amount", + ); + + assert!( + user_schedule.total_claimed <= user_schedule.total_amount, + "Claimed amount exceeds the total allocation", + ); +} + +#[test] +fn test_user_can_claim_all_after_vesting_ended() { + let mut spy = spy_events(); + + let (vesting_contract, erc20_token) = add_schedule(); + let time_stamp = get_block_timestamp(); + let prev_balance = erc20_token.balance_of(RECIPIENT()); + + // duration is 20_000secs + start_cheat_block_timestamp_global(time_stamp + 21_000); + + let claimable = vesting_contract.get_claimable_amount(RECIPIENT()); + + cheat_caller_address(vesting_contract.contract_address, RECIPIENT(), CheatSpan::TargetCalls(1)); + vesting_contract.claim(erc20_token.contract_address); + stop_cheat_block_timestamp_global(); + + let expected_event = Vesting::Event::SuccessfulClaim( + Vesting::SuccessfulClaim { + recipient: RECIPIENT(), token: erc20_token.contract_address, amount: claimable, + }, + ); + + spy.assert_emitted(@array![(vesting_contract.contract_address, expected_event)]); + + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + let curr_balance = erc20_token.balance_of(RECIPIENT()); + + assert!( + curr_balance == prev_balance + claimable, "Recipient balance did not increase as expected", + ); + + assert!( + user_schedule.total_claimed == claimable, + "Claimed amount is less than the claimable amount", + ); + + assert!( + user_schedule.total_claimed == user_schedule.total_amount, + "Claimed amount exceeds the total allocation", + ); +} + +#[test] +fn test_non_registered_user_cannot_claim() { + let (vesting_contract, erc20_token) = add_schedule(); + let time_stamp = get_block_timestamp(); + let prev_balance = erc20_token.balance_of(OTHER()); + + // duration is 20_000secs + start_cheat_block_timestamp_global(time_stamp + 20_000); + + let claimable = vesting_contract.get_claimable_amount(OTHER()); + + cheat_caller_address(vesting_contract.contract_address, OTHER(), CheatSpan::TargetCalls(1)); + vesting_contract.claim(erc20_token.contract_address); + stop_cheat_block_timestamp_global(); + + let user_schedule = vesting_contract.get_user_vesting_schedule(OTHER()); + + let curr_balance = erc20_token.balance_of(OTHER()); + + assert!(curr_balance == prev_balance, "Current balance does not match previous balance"); + + assert!(claimable == 0, "Claimable amount is not zero as expected"); + + assert!(user_schedule.total_claimed == 0, "Total claimed amount is not zero as expected"); +} + +#[test] +fn test_admin_can_remove_vesting() { + let (vesting_contract, erc20_token) = add_schedule(); + let time_stamp = get_block_timestamp(); + + start_cheat_block_timestamp_global(time_stamp + 2_000); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract.remove_schedule(erc20_token.contract_address, RECIPIENT(), OWNER()); + + stop_cheat_block_timestamp_global(); + + let user_schedule = vesting_contract.get_user_vesting_schedule(RECIPIENT()); + + assert!(user_schedule.total_amount == 0, "Expected total amount to be zero"); + + assert!(user_schedule.total_claimed == 0, "Expected total claimed amount to be zero"); + + assert!(user_schedule.recipient == ZERO_ADDRESS(), "Expected recipient to be the zero address"); + + assert!(user_schedule.token == ZERO_ADDRESS(), "Expected token to be the zero address"); + + assert!(user_schedule.cliff_time == 0, "Expected cliff time to be zero"); + + assert!(user_schedule.start_time == 0, "Expected start time to be zero"); + + assert!(user_schedule.end_time == 0, "Expected end time to be zero"); +} + +#[test] +fn test_admin_gets_full_refund_before_cliff() { + let (vesting_contract, erc20_token) = add_schedule(); + let prev_balance = erc20_token.balance_of(OWNER()); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract.remove_schedule(erc20_token.contract_address, RECIPIENT(), OWNER()); + + let curr_balance = erc20_token.balance_of(OWNER()); + + assert!(curr_balance == prev_balance + amount, "Owner balance did not increase as expecte"); +} + +#[test] +fn test_user_gets_claimable_after_cliff() { + let (vesting_contract, erc20_token) = add_schedule(); + let prev_balance = erc20_token.balance_of(RECIPIENT()); + let time_stamp = get_block_timestamp(); + + // cliff is 4_000 + start_cheat_block_timestamp_global(time_stamp + 6_000); + + let claimable = vesting_contract.get_claimable_amount(RECIPIENT()); + + cheat_caller_address(vesting_contract.contract_address, OWNER(), CheatSpan::TargetCalls(1)); + vesting_contract.remove_schedule(erc20_token.contract_address, RECIPIENT(), OWNER()); + + stop_cheat_block_timestamp_global(); + + let curr_balance = erc20_token.balance_of(RECIPIENT()); + + assert!( + curr_balance >= prev_balance + claimable, "Recipient balance did not increase as expecte", + ); +} + +#[test] +#[should_panic] +fn test_non_admin_cannot_remove_schedule() { + let (vesting_contract, erc20_token) = add_schedule(); + + cheat_caller_address( + vesting_contract.contract_address, OTHER_ADMIN(), CheatSpan::TargetCalls(1), + ); + vesting_contract.remove_schedule(erc20_token.contract_address, RECIPIENT(), OTHER_ADMIN()); +}