From 891419c5bba5dab81755582064e124a735915e12 Mon Sep 17 00:00:00 2001 From: Michael Birch Date: Fri, 12 Apr 2024 17:15:40 +0200 Subject: [PATCH 1/5] Feat(stateless-validation): Dynamically compute mandate price from target number of mandates per shard --- chain/epoch-manager/src/test_utils.rs | 9 +- .../epoch-manager/src/validator_selection.rs | 10 +- .../src/validator_mandates/compute_price.rs | 223 ++++++++++++++++++ .../mod.rs} | 92 ++++---- 4 files changed, 279 insertions(+), 55 deletions(-) create mode 100644 core/primitives/src/validator_mandates/compute_price.rs rename core/primitives/src/{validator_mandates.rs => validator_mandates/mod.rs} (87%) diff --git a/chain/epoch-manager/src/test_utils.rs b/chain/epoch-manager/src/test_utils.rs index bff31a2efdf..441e8d4903d 100644 --- a/chain/epoch-manager/src/test_utils.rs +++ b/chain/epoch-manager/src/test_utils.rs @@ -107,11 +107,12 @@ pub fn epoch_info_with_num_seats( }; let all_validators = account_to_validators(accounts); let validator_mandates = { - // TODO(#10014) determine required stake per mandate instead of reusing seat price. - // TODO(#10014) determine `min_mandates_per_shard` let num_shards = chunk_producers_settlement.len(); - let min_mandates_per_shard = 0; - let config = ValidatorMandatesConfig::new(seat_price, min_mandates_per_shard, num_shards); + let total_stake = + all_validators.iter().fold(0_u128, |acc, v| acc.saturating_add(v.stake())); + // For tests we estimate the target number of seats based on the seat price of the old algorithm. + let target_mandates_per_shard = (total_stake / seat_price) as usize; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); ValidatorMandates::new(config, &all_validators) }; EpochInfo::new( diff --git a/chain/epoch-manager/src/validator_selection.rs b/chain/epoch-manager/src/validator_selection.rs index e841bd9f397..d7f60c04a3b 100644 --- a/chain/epoch-manager/src/validator_selection.rs +++ b/chain/epoch-manager/src/validator_selection.rs @@ -187,11 +187,13 @@ pub fn proposals_to_epoch_info( }; let validator_mandates = if checked_feature!("stable", StatelessValidationV0, next_version) { - // TODO(#10014) determine required stake per mandate instead of reusing seat price. - // TODO(#10014) determine `min_mandates_per_shard` - let min_mandates_per_shard = 0; + // Value chosen based on calculations for the security of the protocol. + // With this number of mandates per shard and 6 shards, the theory calculations predict the + // protocol is secure for 40 years (at 90% confidence). + let target_mandates_per_shard = 68; + let num_shards = shard_ids.len(); let validator_mandates_config = - ValidatorMandatesConfig::new(threshold, min_mandates_per_shard, shard_ids.len()); + ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); // We can use `all_validators` to construct mandates Since a validator's position in // `all_validators` corresponds to its `ValidatorId` ValidatorMandates::new(validator_mandates_config, &all_validators) diff --git a/core/primitives/src/validator_mandates/compute_price.rs b/core/primitives/src/validator_mandates/compute_price.rs new file mode 100644 index 00000000000..1c392ac1828 --- /dev/null +++ b/core/primitives/src/validator_mandates/compute_price.rs @@ -0,0 +1,223 @@ +use {super::ValidatorMandatesConfig, near_primitives_core::types::Balance, std::cmp::Ordering}; + +/// Given the stakes for the validators and the target number of mandates to have, +/// this function computes the mandate price to use. It works by iterating a +/// function in an attempt to find its fixed point. This function is motived as follows: +/// Let the validator stakes be denoted by `s_i` a let `S = \sum_i s_i` be the total +/// stake. For a given mandate price `m` we can write each `s_i = m * q_i + r_i` +/// (by the Euclidean algorithm). Hence, the number of whole mandates created by +/// that price is equal to `\sum_i q_i`. If we set this number of whole mandates +/// equal to the target number `N` then substitute back in to the previous equations +/// we have `S = m * N + \sum_i r_i`. We can rearrange this to solve for `m`, +/// `m = (S - \sum_i r_i) / N`. Note that `r_i = a_i % m` so `m` is not truly +/// isolated, but rather the RHS is the expression we want to find the fixed point for. +pub fn compute_mandate_price(config: ValidatorMandatesConfig, stakes: F) -> Balance +where + I: Iterator, + F: Fn() -> I, +{ + let ValidatorMandatesConfig { target_mandates_per_shard, num_shards } = config; + let total_stake = saturating_sum(stakes()); + let target_mandates: u128 = num_shards.saturating_mul(target_mandates_per_shard) as u128; + + let initial_price = total_stake / target_mandates; + + // Function to compute the new estimated mandate price as well as + // evaluate the given mandate price. + let f = |price: u128| { + let mut whole_mandates = 0_u128; + let mut remainders = 0_u128; + for s in stakes() { + whole_mandates = whole_mandates.saturating_add(s / price); + remainders = remainders.saturating_add(s % price); + } + let updated_price = if total_stake > remainders { + (total_stake - remainders) / target_mandates + } else { + // This is an alternate expression we can try to find a fixed point of. + // We use it avoid making the next price equal to 0 (which is clearly incorrect). + // It is derived from `S = m * N + \sum_i r_i` by dividing by `m` first then + // isolating the `m` that appears on the LHS. + let partial_mandates = remainders / price; + total_stake / (target_mandates + partial_mandates) + }; + let mandate_diff = if whole_mandates > target_mandates { + whole_mandates - target_mandates + } else { + target_mandates - whole_mandates + }; + (PriceResult { price, mandate_diff }, updated_price) + }; + + // Iterate the function 25 times + let mut results = [PriceResult::default(); 25]; + let (result_0, mut price) = f(initial_price); + results[0] = result_0; + for result in results.iter_mut().skip(1) { + let (output, next_price) = f(price); + *result = output; + price = next_price; + } + + // Take the best result + let result = results.iter().min().expect("results iter is non-empty"); + result.price +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] +struct PriceResult { + price: u128, + mandate_diff: u128, +} + +impl PartialOrd for PriceResult { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PriceResult { + fn cmp(&self, other: &Self) -> Ordering { + match self.mandate_diff.cmp(&other.mandate_diff) { + Ordering::Equal => self.price.cmp(&other.price), + Ordering::Greater => Ordering::Greater, + Ordering::Less => Ordering::Less, + } + } +} + +fn saturating_sum>(iter: I) -> u128 { + iter.fold(0, |acc, x| acc.saturating_add(x)) +} + +#[cfg(test)] +mod tests { + use rand::{Rng, SeedableRng}; + + use super::*; + + // Test cases where all stakes are equal. + #[test] + fn test_constant_dist() { + let stakes = [11_u128; 13]; + let num_shards = 1; + let target_mandates_per_shard = stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + + // There are enough validators to have 1:1 correspondence with mandates. + assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0]); + + let target_mandates_per_shard = 2 * stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + + // Now each validator needs to take two mandates. + assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0] / 2); + + let target_mandates_per_shard = stakes.len() - 1; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + + // Now there are more validators than we need, but + // the mandate price still doesn't go below the common stake. + assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0]); + } + + // Test cases where the stake distribution is a step function. + #[test] + fn test_step_dist() { + let stakes = { + let mut buf = [11_u128; 13]; + let n = buf.len() / 2; + for s in buf.iter_mut().take(n) { + *s *= 5; + } + buf + }; + let num_shards = 1; + let target_mandates_per_shard = stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + + // Computed price gives whole number of seats close to the target number + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 1); + + let target_mandates_per_shard = 2 * stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 8); + + let target_mandates_per_shard = stakes.len() / 2; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); + } + + // Test cases where the stake distribution is exponential. + #[test] + fn test_exp_dist() { + let stakes = { + let mut buf = vec![1_000_000_000_u128; 210]; + let mut last_stake = buf[0]; + for s in buf.iter_mut().skip(1) { + last_stake = last_stake * 97 / 100; + *s = last_stake; + } + buf + }; + + // This case is similar to the mainnet data. + let num_shards = 6; + let target_mandates_per_shard = 68; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard * num_shards); + + let num_shards = 1; + let target_mandates_per_shard = stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); + + let target_mandates_per_shard = stakes.len() * 2; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); + + let target_mandates_per_shard = stakes.len() / 2; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); + } + + // Test cases where the stakes are chosen uniformly at random. + #[test] + fn test_rand_dist() { + let stakes = { + let mut stakes = vec![0_u128; 1000]; + let mut rng = rand::rngs::StdRng::seed_from_u64(0xdeadbeef); + for s in stakes.iter_mut() { + *s = rng.gen_range(1_u128..10_000u128); + } + stakes + }; + + let num_shards = 1; + let target_mandates_per_shard = stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 21); + + let target_mandates_per_shard = 2 * stakes.len(); + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); + + let target_mandates_per_shard = stakes.len() / 2; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + let price = compute_mandate_price(config, || stakes.iter().copied()); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 31); + } + + fn count_whole_mandates(stakes: &[u128], mandate_price: u128) -> usize { + saturating_sum(stakes.iter().map(|s| *s / mandate_price)) as usize + } +} diff --git a/core/primitives/src/validator_mandates.rs b/core/primitives/src/validator_mandates/mod.rs similarity index 87% rename from core/primitives/src/validator_mandates.rs rename to core/primitives/src/validator_mandates/mod.rs index 1df21995b02..64ef7763375 100644 --- a/core/primitives/src/validator_mandates.rs +++ b/core/primitives/src/validator_mandates/mod.rs @@ -6,16 +6,16 @@ use itertools::Itertools; use near_primitives_core::types::Balance; use rand::{seq::SliceRandom, Rng}; +mod compute_price; + /// Represents the configuration of [`ValidatorMandates`]. Its parameters are expected to remain /// valid for one epoch. #[derive( BorshSerialize, BorshDeserialize, Default, Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, )] pub struct ValidatorMandatesConfig { - /// The amount of stake that corresponds to one mandate. - stake_per_mandate: Balance, - /// The minimum number of mandates required per shard. - min_mandates_per_shard: usize, + /// The desired number of mandates required per shard. + target_mandates_per_shard: usize, /// The number of shards for the referenced epoch. num_shards: usize, } @@ -29,14 +29,9 @@ impl ValidatorMandatesConfig { /// /// - If `stake_per_mandate` is 0 as this would lead to division by 0. /// - If `num_shards` is zero. - pub fn new( - stake_per_mandate: Balance, - min_mandates_per_shard: usize, - num_shards: usize, - ) -> Self { - assert!(stake_per_mandate > 0, "stake_per_mandate of 0 would lead to division by 0"); + pub fn new(target_mandates_per_shard: usize, num_shards: usize) -> Self { assert!(num_shards > 0, "there should be at least one shard"); - Self { stake_per_mandate, min_mandates_per_shard, num_shards } + Self { target_mandates_per_shard, num_shards } } } @@ -54,6 +49,8 @@ impl ValidatorMandatesConfig { pub struct ValidatorMandates { /// The configuration applied to the mandates. config: ValidatorMandatesConfig, + /// The amount of stake a whole mandate is worth. + stake_per_mandate: Balance, /// Each element represents a validator mandate held by the validator with the given id. /// /// The id of a validator who holds `n >= 0` mandates occurs `n` times in the vector. @@ -75,8 +72,10 @@ impl ValidatorMandates { /// Only full mandates are assigned, partial mandates are dropped. For example, when the stake /// required for a mandate is 5 and a validator has staked 12, then it will obtain 2 mandates. pub fn new(config: ValidatorMandatesConfig, validators: &[ValidatorStake]) -> Self { + let stake_per_mandate = + compute_price::compute_mandate_price(config, || validators.iter().map(|v| v.stake())); let num_mandates_per_validator: Vec = - validators.iter().map(|v| v.num_mandates(config.stake_per_mandate)).collect(); + validators.iter().map(|v| v.num_mandates(stake_per_mandate)).collect(); let num_total_mandates = num_mandates_per_validator.iter().map(|&num| usize::from(num)).sum(); let mut mandates: Vec = Vec::with_capacity(num_total_mandates); @@ -88,16 +87,6 @@ impl ValidatorMandates { } } - let required_mandates = config.min_mandates_per_shard * config.num_shards; - if mandates.len() < required_mandates { - // TODO(#10014) dynamically lower `stake_per_mandate` to reach enough mandates - panic!( - "not enough validator mandates: got {}, need {}", - mandates.len(), - required_mandates - ); - } - // Not counting partials towards `required_mandates` as the weight of partials and its // distribution across shards may vary widely. // @@ -105,13 +94,13 @@ impl ValidatorMandates { // divided by `config.stake_per_mandate`, i.e. some validators will have partials. let mut partials = Vec::with_capacity(validators.len()); for i in 0..validators.len() { - let partial_weight = validators[i].partial_mandate_weight(config.stake_per_mandate); + let partial_weight = validators[i].partial_mandate_weight(stake_per_mandate); if partial_weight > 0 { partials.push((i as ValidatorId, partial_weight)); } } - Self { config, mandates, partials } + Self { config, stake_per_mandate, mandates, partials } } /// Returns a validator assignment obtained by shuffling mandates and assigning them to shards. @@ -139,7 +128,7 @@ impl ValidatorMandates { // // Assume, for example, there are 10 mandates and 4 shards. Then for `shard_id = 1` we // collect the mandates with indices 1, 5, and 9. - let stake_per_mandate = self.config.stake_per_mandate; + let stake_per_mandate = self.stake_per_mandate; let mut stake_assignment_per_shard = vec![HashMap::new(); self.config.num_shards]; for shard_id in 0..self.config.num_shards { // Achieve shard id shuffling by writing to the position of the alias of `shard_id`. @@ -278,12 +267,11 @@ mod tests { #[test] fn test_validator_mandates_config_new() { - let stake_per_mandate = 10; - let min_mandates_per_shard = 400; + let target_mandates_per_shard = 400; let num_shards = 4; assert_eq!( - ValidatorMandatesConfig::new(stake_per_mandate, min_mandates_per_shard, num_shards), - ValidatorMandatesConfig { stake_per_mandate, min_mandates_per_shard, num_shards }, + ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards), + ValidatorMandatesConfig { target_mandates_per_shard, num_shards }, ) } @@ -319,18 +307,28 @@ mod tests { #[test] fn test_validator_mandates_new() { let validators = new_validator_stakes(); - let config = ValidatorMandatesConfig::new(10, 1, 4); + let config = ValidatorMandatesConfig::new(3, 4); let mandates = ValidatorMandates::new(config, &validators); - // At 10 stake per mandate, the first validator holds three mandates, and so on. - // Note that "account_2" holds no mandate as its stake is below the threshold. - let expected_mandates: Vec = vec![0, 0, 0, 1, 1, 3, 4, 4, 4]; + // With 3 mandates per shard and 4 shards, we are looking for around 12 total mandates. + // The total stake in `new_validator_stakes` is 123, so to get 12 mandates we need a price + // close to 10. But the algorithm for computing price tries to make the number of _whole_ + // mandates equal to 12, and there are validators with partial mandates in the distribution, + // therefore the price is set a little lower than 10. + assert_eq!(mandates.stake_per_mandate, 8); + + // At 8 stake per mandate, the first validator holds three mandates, and so on. + // Note that "account_5" and "account_6" hold no mandate as both their stakes are below the threshold. + let expected_mandates: Vec = vec![0, 0, 0, 1, 1, 1, 2, 3, 4, 4, 4, 4]; assert_eq!(mandates.mandates, expected_mandates); - // At 10 stake per mandate, the first validator holds no partial mandate, the second - // validator holds a partial mandate with weight 7, and so on. + // The number of whole mandates is exactly equal to our target + assert_eq!(mandates.mandates.len(), config.num_shards * config.target_mandates_per_shard); + + // At 8 stake per mandate, the first validator a partial mandate with weight 6, the second + // validator holds a partial mandate with weight 3, and so on. let expected_partials: Vec<(ValidatorId, Balance)> = - vec![(1, 7), (2, 9), (3, 2), (4, 5), (5, 4), (6, 6)]; + vec![(0, 6), (1, 3), (2, 1), (3, 4), (4, 3), (5, 4), (6, 6)]; assert_eq!(mandates.partials, expected_partials); } @@ -338,7 +336,7 @@ mod tests { fn test_validator_mandates_shuffled_mandates() { // Testing with different `num_shards` values to verify the shuffles used in other tests. assert_validator_mandates_shuffled_mandates(3, vec![0, 1, 4, 4, 3, 1, 4, 0, 0]); - assert_validator_mandates_shuffled_mandates(4, vec![0, 4, 1, 1, 0, 0, 4, 3, 4]); + assert_validator_mandates_shuffled_mandates(4, vec![0, 0, 2, 1, 3, 4, 1, 1, 0, 4, 4, 4]); } fn assert_validator_mandates_shuffled_mandates( @@ -346,7 +344,7 @@ mod tests { expected_assignment: Vec, ) { let validators = new_validator_stakes(); - let config = ValidatorMandatesConfig::new(10, 1, num_shards); + let config = ValidatorMandatesConfig::new(3, num_shards); let mandates = ValidatorMandates::new(config, &validators); let mut rng = new_fixed_rng(); @@ -371,7 +369,7 @@ mod tests { ); assert_validator_mandates_shuffled_partials( 4, - vec![(5, 4), (4, 5), (1, 7), (3, 2), (2, 9), (6, 6)], + vec![(5, 4), (3, 4), (0, 6), (2, 1), (1, 3), (4, 3), (6, 6)], ); } @@ -380,7 +378,7 @@ mod tests { expected_assignment: Vec<(ValidatorId, Balance)>, ) { let validators = new_validator_stakes(); - let config = ValidatorMandatesConfig::new(10, 1, num_shards); + let config = ValidatorMandatesConfig::new(3, num_shards); let mandates = ValidatorMandates::new(config, &validators); let mut rng = new_fixed_rng(); @@ -405,7 +403,7 @@ mod tests { // Assignments in `test_validator_mandates_shuffled_*` can be used to construct // `expected_assignment` below. // Note that shard ids are shuffled too, see `test_shuffled_shard_ids_new`. - let config = ValidatorMandatesConfig::new(10, 1, 3); + let config = ValidatorMandatesConfig::new(3, 3); let expected_assignment = vec![ vec![(1, 17), (4, 10), (6, 06), (0, 10)], vec![(4, 05), (5, 04), (0, 10), (1, 10), (3, 10)], @@ -422,12 +420,12 @@ mod tests { // Assignments in `test_validator_mandates_shuffled_*` can be used to construct // `expected_assignment` below. // Note that shard ids are shuffled too, see `test_shuffled_shard_ids_new`. - let config = ValidatorMandatesConfig::new(10, 1, 4); + let config = ValidatorMandatesConfig::new(3, 4); let expected_mandates_per_shards = vec![ - vec![(1, 07), (4, 10), (0, 20)], - vec![(4, 10), (3, 02), (1, 10)], - vec![(0, 10), (4, 15), (6, 06)], - vec![(3, 10), (5, 04), (2, 09), (1, 10)], + vec![(3, 8), (6, 6), (0, 22)], + vec![(4, 8), (2, 9), (1, 08)], + vec![(0, 8), (3, 4), (4, 19)], + vec![(4, 8), (5, 4), (1, 19)], ]; assert_validator_mandates_sample(config, expected_mandates_per_shards); } @@ -483,7 +481,7 @@ mod tests { #[test] fn test_deterministic_shuffle() { - let config = ValidatorMandatesConfig::new(10, 1, 4); + let config = ValidatorMandatesConfig::new(3, 4); let validators = new_validator_stakes(); let mandates = ValidatorMandates::new(config, &validators); From b2b08de515c887d322ce00ebb9f3eeb5d5a1468a Mon Sep 17 00:00:00 2001 From: Michael Birch Date: Fri, 12 Apr 2024 19:01:48 +0200 Subject: [PATCH 2/5] Fix: handle case where total stake is very small --- .../epoch-manager/src/validator_selection.rs | 8 +++--- .../src/validator_mandates/compute_price.rs | 27 +++++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/chain/epoch-manager/src/validator_selection.rs b/chain/epoch-manager/src/validator_selection.rs index d7f60c04a3b..6e2d23a1b7a 100644 --- a/chain/epoch-manager/src/validator_selection.rs +++ b/chain/epoch-manager/src/validator_selection.rs @@ -838,10 +838,10 @@ mod tests { // Given `epoch_info` and `proposals` above, the sample at a given height is deterministic. let height = 42; let expected_assignments = vec![ - vec![(1, 300), (0, 300), (2, 300), (3, 60)], - vec![(0, 600), (2, 200), (1, 200)], - vec![(3, 200), (2, 300), (1, 100), (0, 400)], - vec![(2, 200), (4, 140), (1, 400), (0, 200)], + vec![(4, 56), (1, 168), (2, 300), (3, 84), (0, 364)], + vec![(3, 70), (1, 300), (4, 42), (2, 266), (0, 308)], + vec![(4, 42), (1, 238), (3, 42), (0, 450), (2, 196)], + vec![(2, 238), (1, 294), (3, 64), (0, 378)], ]; assert_eq!(epoch_info.sample_chunk_validators(height), expected_assignments); } diff --git a/core/primitives/src/validator_mandates/compute_price.rs b/core/primitives/src/validator_mandates/compute_price.rs index 1c392ac1828..7e37ece2206 100644 --- a/core/primitives/src/validator_mandates/compute_price.rs +++ b/core/primitives/src/validator_mandates/compute_price.rs @@ -1,4 +1,8 @@ -use {super::ValidatorMandatesConfig, near_primitives_core::types::Balance, std::cmp::Ordering}; +use { + super::ValidatorMandatesConfig, + near_primitives_core::types::Balance, + std::cmp::{min, Ordering}, +}; /// Given the stakes for the validators and the target number of mandates to have, /// this function computes the mandate price to use. It works by iterating a @@ -18,7 +22,13 @@ where { let ValidatorMandatesConfig { target_mandates_per_shard, num_shards } = config; let total_stake = saturating_sum(stakes()); - let target_mandates: u128 = num_shards.saturating_mul(target_mandates_per_shard) as u128; + + // The target number of mandates cannot be larger than the total amount of stake. + // In production the total stake is _much_ higher than + // `num_shards * target_mandates_per_shard`, but in tests validators are given + // low staked numbers, so we need to have this condition in place. + let target_mandates: u128 = + min(num_shards.saturating_mul(target_mandates_per_shard) as u128, total_stake); let initial_price = total_stake / target_mandates; @@ -96,6 +106,19 @@ mod tests { use super::*; + // Test case where the target number of mandates is larger than the total stake. + // This should never happen in production, but nearcore tests sometimes have + // low stake. + #[test] + fn test_small_total_stake() { + let stakes = [100_u128; 1]; + let num_shards = 1; + let target_mandates_per_shard = 1000; + let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); + + assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), 1); + } + // Test cases where all stakes are equal. #[test] fn test_constant_dist() { From 1f2f35c4af21bb0dad75f3c56e32476f1f6ff561 Mon Sep 17 00:00:00 2001 From: Michael Birch Date: Wed, 17 Apr 2024 15:34:03 +0200 Subject: [PATCH 3/5] Simplify compute_mandate_price function signature --- .../src/validator_mandates/compute_price.rs | 38 +++++++++---------- core/primitives/src/validator_mandates/mod.rs | 4 +- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/core/primitives/src/validator_mandates/compute_price.rs b/core/primitives/src/validator_mandates/compute_price.rs index 7e37ece2206..c91bbcd88dd 100644 --- a/core/primitives/src/validator_mandates/compute_price.rs +++ b/core/primitives/src/validator_mandates/compute_price.rs @@ -15,13 +15,9 @@ use { /// we have `S = m * N + \sum_i r_i`. We can rearrange this to solve for `m`, /// `m = (S - \sum_i r_i) / N`. Note that `r_i = a_i % m` so `m` is not truly /// isolated, but rather the RHS is the expression we want to find the fixed point for. -pub fn compute_mandate_price(config: ValidatorMandatesConfig, stakes: F) -> Balance -where - I: Iterator, - F: Fn() -> I, -{ +pub fn compute_mandate_price(config: ValidatorMandatesConfig, stakes: &[Balance]) -> Balance { let ValidatorMandatesConfig { target_mandates_per_shard, num_shards } = config; - let total_stake = saturating_sum(stakes()); + let total_stake = saturating_sum(stakes.iter().copied()); // The target number of mandates cannot be larger than the total amount of stake. // In production the total stake is _much_ higher than @@ -37,7 +33,7 @@ where let f = |price: u128| { let mut whole_mandates = 0_u128; let mut remainders = 0_u128; - for s in stakes() { + for s in stakes.iter().copied() { whole_mandates = whole_mandates.saturating_add(s / price); remainders = remainders.saturating_add(s % price); } @@ -116,7 +112,7 @@ mod tests { let target_mandates_per_shard = 1000; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), 1); + assert_eq!(compute_mandate_price(config, &stakes), 1); } // Test cases where all stakes are equal. @@ -128,20 +124,20 @@ mod tests { let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); // There are enough validators to have 1:1 correspondence with mandates. - assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0]); + assert_eq!(compute_mandate_price(config, &stakes), stakes[0]); let target_mandates_per_shard = 2 * stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); // Now each validator needs to take two mandates. - assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0] / 2); + assert_eq!(compute_mandate_price(config, &stakes), stakes[0] / 2); let target_mandates_per_shard = stakes.len() - 1; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); // Now there are more validators than we need, but // the mandate price still doesn't go below the common stake. - assert_eq!(compute_mandate_price(config, || stakes.iter().copied()), stakes[0]); + assert_eq!(compute_mandate_price(config, &stakes), stakes[0]); } // Test cases where the stake distribution is a step function. @@ -160,17 +156,17 @@ mod tests { let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); // Computed price gives whole number of seats close to the target number - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 1); let target_mandates_per_shard = 2 * stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 8); let target_mandates_per_shard = stakes.len() / 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); } @@ -191,23 +187,23 @@ mod tests { let num_shards = 6; let target_mandates_per_shard = 68; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard * num_shards); let num_shards = 1; let target_mandates_per_shard = stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); let target_mandates_per_shard = stakes.len() * 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); let target_mandates_per_shard = stakes.len() / 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); } @@ -226,17 +222,17 @@ mod tests { let num_shards = 1; let target_mandates_per_shard = stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 21); let target_mandates_per_shard = 2 * stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); let target_mandates_per_shard = stakes.len() / 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); - let price = compute_mandate_price(config, || stakes.iter().copied()); + let price = compute_mandate_price(config, &stakes); assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 31); } diff --git a/core/primitives/src/validator_mandates/mod.rs b/core/primitives/src/validator_mandates/mod.rs index 64ef7763375..5e936b6492d 100644 --- a/core/primitives/src/validator_mandates/mod.rs +++ b/core/primitives/src/validator_mandates/mod.rs @@ -72,8 +72,8 @@ impl ValidatorMandates { /// Only full mandates are assigned, partial mandates are dropped. For example, when the stake /// required for a mandate is 5 and a validator has staked 12, then it will obtain 2 mandates. pub fn new(config: ValidatorMandatesConfig, validators: &[ValidatorStake]) -> Self { - let stake_per_mandate = - compute_price::compute_mandate_price(config, || validators.iter().map(|v| v.stake())); + let stakes: Vec = validators.iter().map(|v| v.stake()).collect(); + let stake_per_mandate = compute_price::compute_mandate_price(config, &stakes); let num_mandates_per_validator: Vec = validators.iter().map(|v| v.num_mandates(stake_per_mandate)).collect(); let num_total_mandates = From 5ca2a4864f0e9bbd84fc1859893972a56d5ea5d8 Mon Sep 17 00:00:00 2001 From: Michael Birch Date: Wed, 17 Apr 2024 16:03:54 +0200 Subject: [PATCH 4/5] Implement compute_mandate_price using a binary search --- .../src/validator_mandates/compute_price.rs | 144 +++++++++--------- 1 file changed, 75 insertions(+), 69 deletions(-) diff --git a/core/primitives/src/validator_mandates/compute_price.rs b/core/primitives/src/validator_mandates/compute_price.rs index c91bbcd88dd..7fa7434f452 100644 --- a/core/primitives/src/validator_mandates/compute_price.rs +++ b/core/primitives/src/validator_mandates/compute_price.rs @@ -5,16 +5,7 @@ use { }; /// Given the stakes for the validators and the target number of mandates to have, -/// this function computes the mandate price to use. It works by iterating a -/// function in an attempt to find its fixed point. This function is motived as follows: -/// Let the validator stakes be denoted by `s_i` a let `S = \sum_i s_i` be the total -/// stake. For a given mandate price `m` we can write each `s_i = m * q_i + r_i` -/// (by the Euclidean algorithm). Hence, the number of whole mandates created by -/// that price is equal to `\sum_i q_i`. If we set this number of whole mandates -/// equal to the target number `N` then substitute back in to the previous equations -/// we have `S = m * N + \sum_i r_i`. We can rearrange this to solve for `m`, -/// `m = (S - \sum_i r_i) / N`. Note that `r_i = a_i % m` so `m` is not truly -/// isolated, but rather the RHS is the expression we want to find the fixed point for. +/// this function computes the mandate price to use. It works by using a binary search. pub fn compute_mandate_price(config: ValidatorMandatesConfig, stakes: &[Balance]) -> Balance { let ValidatorMandatesConfig { target_mandates_per_shard, num_shards } = config; let total_stake = saturating_sum(stakes.iter().copied()); @@ -26,70 +17,85 @@ pub fn compute_mandate_price(config: ValidatorMandatesConfig, stakes: &[Balance] let target_mandates: u128 = min(num_shards.saturating_mul(target_mandates_per_shard) as u128, total_stake); - let initial_price = total_stake / target_mandates; + // Note: the reason to have the binary search look for the largest mandate price + // which obtains the target number of whole mandates is because the largest value + // minimizes the partial mandates. This can be seen as follows: + // Let `s_i` be the ith stake, `T` be the total stake and `m` be the mandate price. + // T / m = \sum (s_i / m) = \sum q_i + \sum r_i + // ==> \sum q_i = (T / m) - \sum r_i [Eq. (1)] + // where `s_i = m * q_i + r_i` is obtained by the Euclidean algorithm. + // Notice that the LHS of (1) is the number of whole mandates, which we + // are assuming is equal to our target value for some range of `m` values. + // When we use a larger `m` value, `T / m` decreases but we need the LHS + // to remain constant, therefore `\sum r_i` must also decrease. + binary_search(1, total_stake, target_mandates, |mandate_price| { + saturating_sum(stakes.iter().map(|s| *s / mandate_price)) + }) +} - // Function to compute the new estimated mandate price as well as - // evaluate the given mandate price. - let f = |price: u128| { - let mut whole_mandates = 0_u128; - let mut remainders = 0_u128; - for s in stakes.iter().copied() { - whole_mandates = whole_mandates.saturating_add(s / price); - remainders = remainders.saturating_add(s % price); - } - let updated_price = if total_stake > remainders { - (total_stake - remainders) / target_mandates - } else { - // This is an alternate expression we can try to find a fixed point of. - // We use it avoid making the next price equal to 0 (which is clearly incorrect). - // It is derived from `S = m * N + \sum_i r_i` by dividing by `m` first then - // isolating the `m` that appears on the LHS. - let partial_mandates = remainders / price; - total_stake / (target_mandates + partial_mandates) - }; - let mandate_diff = if whole_mandates > target_mandates { - whole_mandates - target_mandates - } else { - target_mandates - whole_mandates - }; - (PriceResult { price, mandate_diff }, updated_price) - }; - - // Iterate the function 25 times - let mut results = [PriceResult::default(); 25]; - let (result_0, mut price) = f(initial_price); - results[0] = result_0; - for result in results.iter_mut().skip(1) { - let (output, next_price) = f(price); - *result = output; - price = next_price; +/// Assume `f` is a non-increasing function (f(x) <= f(y) if x > y) and `low < high`. +/// This function uses a binary search to attempt to find the largest input, `x` such that +/// `f(x) == target`, `low <= x` and `x <= high`. +/// If there is no such `x` then it will return the unique input `x` such that +/// `f(x) > target`, `f(x + 1) < target`, `low <= x` and `x <= high`. +fn binary_search(low: Balance, high: Balance, target: u128, f: F) -> Balance +where + F: Fn(Balance) -> u128, +{ + debug_assert!(low < high); + + let mut low = low; + let mut high = high; + + if f(low) == target { + return highest_exact(low, high, target, f); + } else if f(high) == target { + // No need to use `highest_exact` here because we are already at the upper bound. + return high; } - // Take the best result - let result = results.iter().min().expect("results iter is non-empty"); - result.price -} + while high - low > 1 { + let mid = low + (high - low) / 2; + let f_mid = f(mid); -#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] -struct PriceResult { - price: u128, - mandate_diff: u128, -} - -impl PartialOrd for PriceResult { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + match f_mid.cmp(&target) { + Ordering::Equal => return highest_exact(mid, high, target, f), + Ordering::Less => high = mid, + Ordering::Greater => low = mid, + } } + + // No exact answer, return best price which gives an answer greater than + // `target_mandates` (which is `low` because `count_whole_mandates` is a non-increasing function). + low } -impl Ord for PriceResult { - fn cmp(&self, other: &Self) -> Ordering { - match self.mandate_diff.cmp(&other.mandate_diff) { - Ordering::Equal => self.price.cmp(&other.price), - Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, +/// Assume `f` is a non-increasing function (f(x) <= f(y) if x > y), `f(low) == target` +/// and `f(high) < target`. This function uses a binary search to find the largest input, `x` +/// such that `f(x) == target`. +fn highest_exact(low: Balance, high: Balance, target: u128, f: F) -> Balance +where + F: Fn(Balance) -> u128, +{ + debug_assert!(low < high); + debug_assert_eq!(f(low), target); + debug_assert!(f(high) < target); + + let mut low = low; + let mut high = high; + + while high - low > 1 { + let mid = low + (high - low) / 2; + let f_mid = f(mid); + + match f_mid.cmp(&target) { + Ordering::Equal => low = mid, + Ordering::Less => high = mid, + Ordering::Greater => unreachable!("Given function must be non-increasing"), } } + + low } fn saturating_sum>(iter: I) -> u128 { @@ -157,12 +163,12 @@ mod tests { // Computed price gives whole number of seats close to the target number let price = compute_mandate_price(config, &stakes); - assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 1); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 5); let target_mandates_per_shard = 2 * stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); let price = compute_mandate_price(config, &stakes); - assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 8); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 11); let target_mandates_per_shard = stakes.len() / 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); @@ -223,7 +229,7 @@ mod tests { let target_mandates_per_shard = stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); let price = compute_mandate_price(config, &stakes); - assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 21); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 3); let target_mandates_per_shard = 2 * stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); @@ -233,7 +239,7 @@ mod tests { let target_mandates_per_shard = stakes.len() / 2; let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); let price = compute_mandate_price(config, &stakes); - assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard - 31); + assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard); } fn count_whole_mandates(stakes: &[u128], mandate_price: u128) -> usize { From 532663027578c53a4afd29ca4ff49906cfee4e78 Mon Sep 17 00:00:00 2001 From: Michael Birch Date: Thu, 18 Apr 2024 23:30:36 +0200 Subject: [PATCH 5/5] Comment on not getting the exact target number of mandates in test_rand_dist --- core/primitives/src/validator_mandates/compute_price.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/primitives/src/validator_mandates/compute_price.rs b/core/primitives/src/validator_mandates/compute_price.rs index 7fa7434f452..f54fdfb75cb 100644 --- a/core/primitives/src/validator_mandates/compute_price.rs +++ b/core/primitives/src/validator_mandates/compute_price.rs @@ -229,6 +229,10 @@ mod tests { let target_mandates_per_shard = stakes.len(); let config = ValidatorMandatesConfig::new(target_mandates_per_shard, num_shards); let price = compute_mandate_price(config, &stakes); + // In this case it was not possible to find a seat price that exactly results + // in the target number of mandates. This is simply due to the discrete nature + // of the problem. But the algorithm still gets very close (3 out of 1000 is + // 0.3% off the target). assert_eq!(count_whole_mandates(&stakes, price), target_mandates_per_shard + 3); let target_mandates_per_shard = 2 * stakes.len();