Skip to content

Commit

Permalink
Add ensure_pow method (paritytech#13042)
Browse files Browse the repository at this point in the history
* add ensure_pow method

* reexport checked_pow and ensure_pow
  • Loading branch information
lemunozm authored and ltfschoen committed Feb 22, 2023
1 parent 0dc44b9 commit 98b81aa
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
52 changes: 42 additions & 10 deletions primitives/arithmetic/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

use codec::HasCompact;
pub use ensure::{
Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign, EnsureFixedPointNumber,
EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp, EnsureOpAssign, EnsureSub,
EnsureSubAssign,
ensure_pow, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign,
EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp,
EnsureOpAssign, EnsureSub, EnsureSubAssign,
};
pub use integer_sqrt::IntegerSquareRoot;
pub use num_traits::{
Expand Down Expand Up @@ -342,7 +342,7 @@ impl<T: Sized> SaturatedConversion for T {}
/// The *EnsureOps* family functions follows the same behavior as *CheckedOps* but
/// returning an [`ArithmeticError`](crate::ArithmeticError) instead of `None`.
mod ensure {
use super::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, Zero};
use super::{checked_pow, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Zero};
use crate::{ArithmeticError, FixedPointNumber, FixedPointOperand};

/// Performs addition that returns [`ArithmeticError`] instead of wrapping around on overflow.
Expand Down Expand Up @@ -511,6 +511,27 @@ mod ensure {
}
}

/// Raises a value to the power of exp, returning `ArithmeticError` if an overflow occurred.
///
/// Check [`checked_pow`] for more info about border cases.
///
/// ```
/// use sp_arithmetic::{traits::ensure_pow, ArithmeticError};
///
/// fn overflow() -> Result<(), ArithmeticError> {
/// ensure_pow(2u64, 64)?;
/// Ok(())
/// }
///
/// assert_eq!(overflow(), Err(ArithmeticError::Overflow));
/// ```
pub fn ensure_pow<T: One + CheckedMul + Clone>(
base: T,
exp: usize,
) -> Result<T, ArithmeticError> {
checked_pow(base, exp).ok_or(ArithmeticError::Overflow)
}

impl<T: EnsureAddAssign> EnsureAdd for T {}
impl<T: EnsureSubAssign> EnsureSub for T {}
impl<T: EnsureMulAssign> EnsureMul for T {}
Expand Down Expand Up @@ -953,6 +974,15 @@ mod tests {
test_ensure(values(), &EnsureDiv::ensure_div, &CheckedDiv::checked_div);
}

#[test]
fn ensure_pow_works() {
test_ensure(
values().into_iter().map(|(base, exp)| (base, exp as usize)).collect(),
ensure_pow,
|&a, &b| checked_pow(a, b),
);
}

#[test]
fn ensure_add_assign_works() {
test_ensure_assign(values(), &EnsureAddAssign::ensure_add_assign, &EnsureAdd::ensure_add);
Expand All @@ -974,11 +1004,12 @@ mod tests {
}

/// Test that the ensured function returns the expected un-ensured value.
fn test_ensure<V, E, P>(pairs: Vec<(V, V)>, ensured: E, unensured: P)
fn test_ensure<V, W, E, P>(pairs: Vec<(V, W)>, ensured: E, unensured: P)
where
V: Ensure + core::fmt::Debug + Copy,
E: Fn(V, V) -> Result<V, ArithmeticError>,
P: Fn(&V, &V) -> Option<V>,
W: Ensure + core::fmt::Debug + Copy,
E: Fn(V, W) -> Result<V, ArithmeticError>,
P: Fn(&V, &W) -> Option<V>,
{
for (a, b) in pairs.into_iter() {
match ensured(a, b) {
Expand All @@ -993,11 +1024,12 @@ mod tests {
}

/// Test that the ensured function modifies `self` to the expected un-ensured value.
fn test_ensure_assign<V, E, P>(pairs: Vec<(V, V)>, ensured: E, unensured: P)
fn test_ensure_assign<V, W, E, P>(pairs: Vec<(V, W)>, ensured: E, unensured: P)
where
V: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy,
E: Fn(&mut V, V) -> Result<(), ArithmeticError>,
P: Fn(V, V) -> Result<V, ArithmeticError> + std::panic::RefUnwindSafe,
W: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy,
E: Fn(&mut V, W) -> Result<(), ArithmeticError>,
P: Fn(V, W) -> Result<V, ArithmeticError> + std::panic::RefUnwindSafe,
{
for (mut a, b) in pairs.into_iter() {
let old_a = a;
Expand Down
10 changes: 5 additions & 5 deletions primitives/runtime/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ use impl_trait_for_tuples::impl_for_tuples;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sp_application_crypto::AppKey;
pub use sp_arithmetic::traits::{
AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedShl,
CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign,
EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp,
EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One, SaturatedConversion,
Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero,
checked_pow, ensure_pow, AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv,
CheckedMul, CheckedShl, CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv,
EnsureDivAssign, EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign,
EnsureOp, EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One,
SaturatedConversion, Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero,
};
use sp_core::{self, storage::StateVersion, Hasher, RuntimeDebug, TypeId};
#[doc(hidden)]
Expand Down

0 comments on commit 98b81aa

Please sign in to comment.