Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Serde Serialize and Deserialize traits for the RbTree in the ic-certified-map crate. #399

Merged
merged 15 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions library/ic-certified-map/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- Implement CandidType, Serialize, and Deserialize for the RbTree.

## [0.4.0] - 2023-07-13

### Changed
Expand Down
3 changes: 2 additions & 1 deletion library/ic-certified-map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ include = ["src", "Cargo.toml", "CHANGELOG.md", "LICENSE", "README.md"]
serde.workspace = true
serde_bytes.workspace = true
sha2.workspace = true
candid.workspace = true

[dev-dependencies]
hex.workspace = true
serde_cbor = "0.11"
ic-cdk.workspace = true
candid.workspace = true
bincode = "1.3.3"
166 changes: 126 additions & 40 deletions library/ic-certified-map/src/rbtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl AsHashTree for Hash {
}
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> AsHashTree for RbTree<K, V> {
fn root_hash(&self) -> Hash {
match self.root.as_ref() {
None => Empty.reconstruct(),
Expand Down Expand Up @@ -102,7 +102,7 @@ struct Node<K, V> {
subtree_hash: Hash,
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> Node<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> Node<K, V> {
fn new(key: K, value: V) -> Box<Node<K, V>> {
let value_hash = value.root_hash();
let data_hash = labeled_hash(key.as_ref(), &value_hash);
Expand Down Expand Up @@ -274,47 +274,47 @@ pub struct RbTree<K, V> {
root: NodeRef<K, V>,
}

impl<K, V> PartialEq for RbTree<K, V>
impl<'t, K, V> PartialEq for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + PartialEq,
V: 'static + AsHashTree + PartialEq,
K: 't + AsRef<[u8]> + PartialEq,
V: 't + AsHashTree + PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.iter().eq(other.iter())
}
}

impl<K, V> Eq for RbTree<K, V>
impl<'t, K, V> Eq for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + Eq,
V: 'static + AsHashTree + Eq,
K: 't + AsRef<[u8]> + Eq,
V: 't + AsHashTree + Eq,
{
}

impl<K, V> PartialOrd for RbTree<K, V>
impl<'t, K, V> PartialOrd for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + PartialOrd,
V: 'static + AsHashTree + PartialOrd,
K: 't + AsRef<[u8]> + PartialOrd,
V: 't + AsHashTree + PartialOrd,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.iter().partial_cmp(other.iter())
}
}

impl<K, V> Ord for RbTree<K, V>
impl<'t, K, V> Ord for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + Ord,
V: 'static + AsHashTree + Ord,
K: 't + AsRef<[u8]> + Ord,
V: 't + AsHashTree + Ord,
{
fn cmp(&self, other: &Self) -> Ordering {
self.iter().cmp(other.iter())
}
}

impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
impl<'t, K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
where
K: 'static + AsRef<[u8]>,
V: 'static + AsHashTree,
K: 't + AsRef<[u8]>,
V: 't + AsHashTree,
{
fn from_iter<T>(iter: T) -> Self
where
Expand All @@ -328,10 +328,10 @@ where
}
}

impl<K, V> std::fmt::Debug for RbTree<K, V>
impl<'t, K, V> std::fmt::Debug for RbTree<K, V>
where
K: 'static + AsRef<[u8]> + std::fmt::Debug,
V: 'static + AsHashTree + std::fmt::Debug,
K: 't + AsRef<[u8]> + std::fmt::Debug,
V: 't + AsHashTree + std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
Expand Down Expand Up @@ -359,7 +359,7 @@ impl<K, V> RbTree<K, V> {
}
}

impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
impl<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't> RbTree<K, V> {
/// Looks up the key in the map and returns the associated value, if there is one.
pub fn get(&self, key: &[u8]) -> Option<&V> {
let mut root = self.root.as_ref();
Expand All @@ -375,7 +375,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Updates the value corresponding to the specified key.
pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
h: &mut NodeRef<K, V>,
k: &[u8],
f: impl FnOnce(&mut V),
Expand Down Expand Up @@ -506,7 +506,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
lo: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
) -> HashTree<'a> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
lo: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
Expand Down Expand Up @@ -543,7 +543,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
hi: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
) -> HashTree<'a> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
hi: KeyBound<'a>,
f: fn(&'a Node<K, V>) -> HashTree<'a>,
Expand Down Expand Up @@ -587,7 +587,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
lo.as_ref(),
hi.as_ref()
);
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
lo: KeyBound<'a>,
hi: KeyBound<'a>,
Expand Down Expand Up @@ -645,7 +645,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
key: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -662,7 +662,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
key: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -685,7 +685,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}
&x[0..p.len()] == p
}
fn go<'a, K: 'static + AsRef<[u8]>, V>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
n: &'a NodeRef<K, V>,
prefix: &[u8],
) -> Option<KeyBound<'a>> {
Expand All @@ -706,7 +706,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
key: &[u8],
f: impl FnOnce(&'a V) -> HashTree<'a>,
) -> Option<HashTree<'a>> {
fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
n: &'a NodeRef<K, V>,
key: &[u8],
f: impl FnOnce(&'a V) -> HashTree<'a>,
Expand Down Expand Up @@ -740,7 +740,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Inserts a key-value entry into the map.
pub fn insert(&mut self, key: K, value: V) {
fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
h: NodeRef<K, V>,
k: K,
v: V,
Expand Down Expand Up @@ -778,7 +778,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {

/// Removes the specified key from the map.
pub fn delete(&mut self, key: &[u8]) {
fn move_red_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn move_red_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
flip_colors(&mut h);
Expand All @@ -790,7 +790,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
h
}

fn move_red_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn move_red_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
flip_colors(&mut h);
Expand All @@ -802,7 +802,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}

#[inline]
fn min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: &mut Box<Node<K, V>>,
) -> &mut Box<Node<K, V>> {
while h.left.is_some() {
Expand All @@ -811,7 +811,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
h
}

fn delete_min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn delete_min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> NodeRef<K, V> {
if h.left.is_none() {
Expand All @@ -827,7 +827,7 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
Some(balance(h))
}

fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
key: &[u8],
) -> NodeRef<K, V> {
Expand Down Expand Up @@ -888,6 +888,94 @@ impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
}
}

use candid::CandidType;

impl<'t, K, V> CandidType for RbTree<K, V>
where
K: CandidType + AsRef<[u8]> + 't,
V: CandidType + AsHashTree + 't,
{
fn _ty() -> candid::types::internal::Type {
<Vec<(&K, &V)> as CandidType>::_ty()
}
fn idl_serialize<S: candid::types::Serializer>(&self, serializer: S) -> Result<(), S::Error> {
let collect_as_vec = self.iter().collect::<Vec<(&K, &V)>>();
<Vec<(&K, &V)> as CandidType>::idl_serialize(&collect_as_vec, serializer)
}
}

use serde::{
de::{Deserialize, Deserializer, MapAccess, Visitor},
ser::{Serialize, SerializeMap, Serializer},
};
use std::marker::PhantomData;

impl<'t, K, V> Serialize for RbTree<K, V>
where
K: Serialize + AsRef<[u8]> + 't,
V: Serialize + AsHashTree + 't,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(self.iter().count()))?;
for (k, v) in self.iter() {
map.serialize_entry(k, v)?;
}
map.end()
}
}

// The PhantomData keeps the compiler from complaining about unused generic type parameters.
struct RbTreeSerdeVisitor<K, V> {
marker: PhantomData<fn() -> RbTree<K, V>>,
}

impl<K, V> RbTreeSerdeVisitor<K, V> {
fn new() -> Self {
RbTreeSerdeVisitor {
marker: PhantomData,
}
}
}

impl<'de, 't, K, V> Visitor<'de> for RbTreeSerdeVisitor<K, V>
where
K: Deserialize<'de> + AsRef<[u8]> + 't,
V: Deserialize<'de> + AsHashTree + 't,
{
type Value = RbTree<K, V>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a map")
}

fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut t = RbTree::<K, V>::new();
while let Some((key, value)) = access.next_entry()? {
t.insert(key, value);
}
Ok(t)
}
}

impl<'de, 't, K, V> Deserialize<'de> for RbTree<K, V>
where
K: Deserialize<'de> + AsRef<[u8]> + 't,
V: Deserialize<'de> + AsHashTree + 't,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(RbTreeSerdeVisitor::new())
}
}

fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
match (l, m, r) {
(Empty, m, Empty) => m,
Expand All @@ -906,9 +994,7 @@ fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
x.as_ref().map(|h| h.color == Color::Red).unwrap_or(false)
}

fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
fn balance<'t, K: AsRef<[u8]> + 't, V: AsHashTree + 't>(mut h: Box<Node<K, V>>) -> Box<Node<K, V>> {
if is_red(&h.right) && !is_red(&h.left) {
h = rotate_left(h);
}
Expand All @@ -922,7 +1008,7 @@ fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
}

/// Make a left-leaning link lean to the right.
fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn rotate_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
debug_assert!(is_red(&h.left));
Expand All @@ -939,7 +1025,7 @@ fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
x
}

fn rotate_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
fn rotate_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
mut h: Box<Node<K, V>>,
) -> Box<Node<K, V>> {
debug_assert!(is_red(&h.right));
Expand Down
Loading