Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement << and <<= for all uints #1723

Merged
merged 6 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ and this project adheres to

## [Unreleased]

## Added

- cosmwasm-std: Add `<<` and `<<=` implementation for `Uint{64,128,256,512}`
types. ([#1723])

[#1723]: https://github.com/CosmWasm/cosmwasm/pull/1723

## [1.2.6] - 2023-06-05

## Changed
Expand Down
20 changes: 19 additions & 1 deletion packages/std/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ mod tests {
use super::*;
use std::ops::*;

/// An trait that ensures other traits are implemented for our number types
/// A trait that ensures other traits are implemented for our number types
trait AllImpl<'a>:
Add
+ Add<&'a Self>
Expand Down Expand Up @@ -50,10 +50,28 @@ mod tests {
{
}

/// A trait that ensures other traits are implemented for our integer types
trait IntImpl<'a>:
AllImpl<'a>
+ Shl<u32>
+ Shl<&'a u32>
+ ShlAssign<u32>
+ ShlAssign<&'a u32>
+ Shr<u32>
+ Shr<&'a u32>
+ ShrAssign<u32>
+ ShrAssign<&'a u32>
{
}

impl AllImpl<'_> for Uint64 {}
impl AllImpl<'_> for Uint128 {}
impl AllImpl<'_> for Uint256 {}
impl AllImpl<'_> for Uint512 {}
impl IntImpl<'_> for Uint64 {}
impl IntImpl<'_> for Uint128 {}
impl IntImpl<'_> for Uint256 {}
impl IntImpl<'_> for Uint512 {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove impl AllImpl<'_> for Uint512 {} and friends now? They should be required implicitly through the supertrait.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if we copy all the bounds from AllImpl to IntImpl. The way it is right now, IntImpl requires that they have AllImpl too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm referring to line 67-70 in the new code. There we implement

    impl AllImpl<'_> for Uint64 {}
    impl AllImpl<'_> for Uint128 {}
    impl AllImpl<'_> for Uint256 {}
    impl AllImpl<'_> for Uint512 {}

which requires the Uint* types to implement the AllImpl traits. However, line 71-74 should imply that.

E.g. impl IntImpl<'_> for Uint64 {} failes to compile if Uint64 misses Add<&'a Self> since IntImpl requires AllImpl which requires Add<&'a Self>.

Or am I missing something here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know what you mean, but that's not how it works. Let's use Uint64 as an example: Line 71 implies that it implements AllImpl and all the Sh* traits. If we remove line 67, then it does not implement AllImpl anymore and line 71 will fail to compile. AllImpl does not get automatically implemented just because all the bounds are implemented.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right you are, amen

impl AllImpl<'_> for Decimal {}
impl AllImpl<'_> for Decimal256 {}
}
89 changes: 88 additions & 1 deletion packages/std/src/math/uint128.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::{self};
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shr, ShrAssign, Sub, SubAssign,
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign,
Sub, SubAssign,
};
use std::str::FromStr;

Expand Down Expand Up @@ -197,6 +198,22 @@ impl Uint128 {
.ok_or_else(|| DivideByZeroError::new(self))
}

pub fn checked_shr(self, other: u32) -> Result<Self, OverflowError> {
if other >= 128 {
return Err(OverflowError::new(OverflowOperation::Shr, self, other));
}

Ok(Self(self.0.shr(other)))
}

pub fn checked_shl(self, other: u32) -> Result<Self, OverflowError> {
if other >= 128 {
return Err(OverflowError::new(OverflowOperation::Shl, self, other));
}

Ok(Self(self.0.shl(other)))
}

#[must_use = "this returns the result of the operation, without modifying the original"]
#[inline]
pub fn wrapping_add(self, other: Self) -> Self {
Expand Down Expand Up @@ -441,6 +458,26 @@ impl<'a> Shr<&'a u32> for Uint128 {
}
}

impl Shl<u32> for Uint128 {
type Output = Self;

fn shl(self, rhs: u32) -> Self::Output {
Self(
self.u128()
.checked_shl(rhs)
.expect("attempt to shift left with overflow"),
)
}
}

impl<'a> Shl<&'a u32> for Uint128 {
type Output = Self;

fn shl(self, rhs: &'a u32) -> Self::Output {
self.shl(*rhs)
}
}

impl AddAssign<Uint128> for Uint128 {
fn add_assign(&mut self, rhs: Uint128) {
*self = *self + rhs;
Expand Down Expand Up @@ -497,6 +534,18 @@ impl<'a> ShrAssign<&'a u32> for Uint128 {
}
}

impl ShlAssign<u32> for Uint128 {
fn shl_assign(&mut self, rhs: u32) {
*self = Shl::<u32>::shl(*self, rhs);
}
}

impl<'a> ShlAssign<&'a u32> for Uint128 {
fn shl_assign(&mut self, rhs: &'a u32) {
*self = Shl::<u32>::shl(*self, *rhs);
}
}

impl Serialize for Uint128 {
/// Serializes as an integer string using base 10
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down Expand Up @@ -889,6 +938,44 @@ mod tests {
);
}

#[test]
fn uint128_shr_works() {
let original = Uint128::new(u128::from_be_bytes([
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 2u8, 0u8, 4u8, 2u8,
]));

let shifted = Uint128::new(u128::from_be_bytes([
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 128u8, 1u8, 0u8,
]));

assert_eq!(original >> 2u32, shifted);
}

#[test]
#[should_panic]
fn uint128_shr_overflow_panics() {
let _ = Uint128::from(1u32) >> 128u32;
}

#[test]
fn uint128_shl_works() {
let original = Uint128::new(u128::from_be_bytes([
64u8, 128u8, 1u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]));

let shifted = Uint128::new(u128::from_be_bytes([
2u8, 0u8, 4u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]));

assert_eq!(original << 2u32, shifted);
}

#[test]
#[should_panic]
fn uint128_shl_overflow_panics() {
let _ = Uint128::from(1u32) << 128u32;
}

#[test]
fn sum_works() {
let nums = vec![Uint128(17), Uint128(123), Uint128(540), Uint128(82)];
Expand Down
45 changes: 37 additions & 8 deletions packages/std/src/math/uint256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use schemars::JsonSchema;
use serde::{de, ser, Deserialize, Deserializer, Serialize};
use std::fmt;
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, Shr, ShrAssign, Sub,
SubAssign,
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign,
Sub, SubAssign,
};
use std::str::FromStr;

Expand Down Expand Up @@ -575,12 +575,8 @@ impl Shl<u32> for Uint256 {
type Output = Self;

fn shl(self, rhs: u32) -> Self::Output {
self.checked_shl(rhs).unwrap_or_else(|_| {
panic!(
"left shift error: {} is larger or equal than the number of bits in Uint256",
rhs,
)
})
self.checked_shl(rhs)
.expect("attempt to shift left with overflow")
}
}

Expand Down Expand Up @@ -628,6 +624,18 @@ impl<'a> ShrAssign<&'a u32> for Uint256 {
}
}

impl ShlAssign<u32> for Uint256 {
fn shl_assign(&mut self, rhs: u32) {
*self = self.shl(rhs);
}
}

impl<'a> ShlAssign<&'a u32> for Uint256 {
fn shl_assign(&mut self, rhs: &'a u32) {
*self = self.shl(*rhs);
}
}

impl Serialize for Uint256 {
/// Serializes as an integer string using base 10
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down Expand Up @@ -1496,6 +1504,27 @@ mod tests {
let _ = Uint256::from(1u32) >> 256u32;
}

#[test]
fn uint256_shl_works() {
let original = Uint256::new([
64u8, 128u8, 1u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]);

let shifted = Uint256::new([
2u8, 0u8, 4u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]);

assert_eq!(original << 2u32, shifted);
}

#[test]
#[should_panic]
fn uint256_shl_overflow_panics() {
let _ = Uint256::from(1u32) << 256u32;
}

#[test]
fn sum_works() {
let nums = vec![
Expand Down
65 changes: 64 additions & 1 deletion packages/std/src/math/uint512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use schemars::JsonSchema;
use serde::{de, ser, Deserialize, Deserializer, Serialize};
use std::fmt;
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shr, ShrAssign, Sub, SubAssign,
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign,
Sub, SubAssign,
};
use std::str::FromStr;

Expand Down Expand Up @@ -262,6 +263,14 @@ impl Uint512 {
Ok(Self(self.0.shr(other)))
}

pub fn checked_shl(self, other: u32) -> Result<Self, OverflowError> {
if other >= 512 {
return Err(OverflowError::new(OverflowOperation::Shl, self, other));
}

Ok(Self(self.0.shl(other)))
}

#[must_use = "this returns the result of the operation, without modifying the original"]
#[inline]
pub fn wrapping_add(self, other: Self) -> Self {
Expand Down Expand Up @@ -542,6 +551,23 @@ impl<'a> Shr<&'a u32> for Uint512 {
}
}

impl Shl<u32> for Uint512 {
type Output = Self;

fn shl(self, rhs: u32) -> Self::Output {
self.checked_shl(rhs)
.expect("attempt to shift left with overflow")
}
}

impl<'a> Shl<&'a u32> for Uint512 {
type Output = Self;

fn shl(self, rhs: &'a u32) -> Self::Output {
self.shl(*rhs)
}
}

impl AddAssign<Uint512> for Uint512 {
fn add_assign(&mut self, rhs: Uint512) {
self.0 = self.0.checked_add(rhs.0).unwrap();
Expand Down Expand Up @@ -578,6 +604,18 @@ impl<'a> ShrAssign<&'a u32> for Uint512 {
}
}

impl ShlAssign<u32> for Uint512 {
fn shl_assign(&mut self, rhs: u32) {
*self = self.shl(rhs);
}
}

impl<'a> ShlAssign<&'a u32> for Uint512 {
fn shl_assign(&mut self, rhs: &'a u32) {
*self = self.shl(*rhs);
}
}

impl Serialize for Uint512 {
/// Serializes as an integer string using base 10
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down Expand Up @@ -1119,6 +1157,31 @@ mod tests {
let _ = Uint512::from(1u32) >> 512u32;
}

#[test]
fn uint512_shl_works() {
let original = Uint512::new([
64u8, 128u8, 1u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]);

let shifted = Uint512::new([
2u8, 0u8, 4u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
]);

assert_eq!(original << 2u32, shifted);
}

#[test]
#[should_panic]
fn uint512_shl_overflow_panics() {
let _ = Uint512::from(1u32) << 512u32;
}

#[test]
fn sum_works() {
let nums = vec![
Expand Down
Loading