diff --git a/primitives/arithmetic/src/traits.rs b/primitives/arithmetic/src/traits.rs index 9e9dff8e6f9ea..dfba046754373 100644 --- a/primitives/arithmetic/src/traits.rs +++ b/primitives/arithmetic/src/traits.rs @@ -19,9 +19,9 @@ use codec::HasCompact; pub use ensure::{ - Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign, EnsureFixedPointNumber, - EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp, EnsureOpAssign, EnsureSub, - EnsureSubAssign, + ensure_pow, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign, + EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp, + EnsureOpAssign, EnsureSub, EnsureSubAssign, }; pub use integer_sqrt::IntegerSquareRoot; pub use num_traits::{ @@ -342,7 +342,7 @@ impl SaturatedConversion for T {} /// The *EnsureOps* family functions follows the same behavior as *CheckedOps* but /// returning an [`ArithmeticError`](crate::ArithmeticError) instead of `None`. mod ensure { - use super::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, Zero}; + use super::{checked_pow, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Zero}; use crate::{ArithmeticError, FixedPointNumber, FixedPointOperand}; /// Performs addition that returns [`ArithmeticError`] instead of wrapping around on overflow. @@ -511,6 +511,27 @@ mod ensure { } } + /// Raises a value to the power of exp, returning `ArithmeticError` if an overflow occurred. + /// + /// Check [`checked_pow`] for more info about border cases. + /// + /// ``` + /// use sp_arithmetic::{traits::ensure_pow, ArithmeticError}; + /// + /// fn overflow() -> Result<(), ArithmeticError> { + /// ensure_pow(2u64, 64)?; + /// Ok(()) + /// } + /// + /// assert_eq!(overflow(), Err(ArithmeticError::Overflow)); + /// ``` + pub fn ensure_pow( + base: T, + exp: usize, + ) -> Result { + checked_pow(base, exp).ok_or(ArithmeticError::Overflow) + } + impl EnsureAdd for T {} impl EnsureSub for T {} impl EnsureMul for T {} @@ -953,6 +974,15 @@ mod tests { test_ensure(values(), &EnsureDiv::ensure_div, &CheckedDiv::checked_div); } + #[test] + fn ensure_pow_works() { + test_ensure( + values().into_iter().map(|(base, exp)| (base, exp as usize)).collect(), + ensure_pow, + |&a, &b| checked_pow(a, b), + ); + } + #[test] fn ensure_add_assign_works() { test_ensure_assign(values(), &EnsureAddAssign::ensure_add_assign, &EnsureAdd::ensure_add); @@ -974,11 +1004,12 @@ mod tests { } /// Test that the ensured function returns the expected un-ensured value. - fn test_ensure(pairs: Vec<(V, V)>, ensured: E, unensured: P) + fn test_ensure(pairs: Vec<(V, W)>, ensured: E, unensured: P) where V: Ensure + core::fmt::Debug + Copy, - E: Fn(V, V) -> Result, - P: Fn(&V, &V) -> Option, + W: Ensure + core::fmt::Debug + Copy, + E: Fn(V, W) -> Result, + P: Fn(&V, &W) -> Option, { for (a, b) in pairs.into_iter() { match ensured(a, b) { @@ -993,11 +1024,12 @@ mod tests { } /// Test that the ensured function modifies `self` to the expected un-ensured value. - fn test_ensure_assign(pairs: Vec<(V, V)>, ensured: E, unensured: P) + fn test_ensure_assign(pairs: Vec<(V, W)>, ensured: E, unensured: P) where V: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy, - E: Fn(&mut V, V) -> Result<(), ArithmeticError>, - P: Fn(V, V) -> Result + std::panic::RefUnwindSafe, + W: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy, + E: Fn(&mut V, W) -> Result<(), ArithmeticError>, + P: Fn(V, W) -> Result + std::panic::RefUnwindSafe, { for (mut a, b) in pairs.into_iter() { let old_a = a; diff --git a/primitives/runtime/src/traits.rs b/primitives/runtime/src/traits.rs index 8978cdb11c0c6..6eb19683b4439 100644 --- a/primitives/runtime/src/traits.rs +++ b/primitives/runtime/src/traits.rs @@ -32,11 +32,11 @@ use impl_trait_for_tuples::impl_for_tuples; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sp_application_crypto::AppKey; pub use sp_arithmetic::traits::{ - AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedShl, - CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign, - EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp, - EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One, SaturatedConversion, - Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero, + checked_pow, ensure_pow, AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv, + CheckedMul, CheckedShl, CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, + EnsureDivAssign, EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, + EnsureOp, EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One, + SaturatedConversion, Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero, }; use sp_core::{self, storage::StateVersion, Hasher, RuntimeDebug, TypeId}; #[doc(hidden)]