Skip to content

Commit

Permalink
feat: port hamt validation logic to kamt (#1976)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stebalien authored Jan 29, 2024
1 parent 40d396e commit 0bbb3a0
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 86 deletions.
20 changes: 13 additions & 7 deletions ipld/kamt/src/bitfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ use byteorder::{BigEndian, ByteOrder};
use fvm_ipld_encoding::de::{Deserialize, Deserializer};
use fvm_ipld_encoding::ser::{Serialize, Serializer};
use fvm_ipld_encoding::strict_bytes;
use serde::de::Error;

const MAX_LEN: usize = 4 * 8;

#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Bitfield([u64; 4]);
Expand All @@ -17,7 +20,7 @@ impl Serialize for Bitfield {
where
S: Serializer,
{
let mut v = [0u8; 4 * 8];
let mut v = [0u8; MAX_LEN];
// Big endian ordering, to match go
BigEndian::write_u64(&mut v[..8], self.0[3]);
BigEndian::write_u64(&mut v[8..16], self.0[2]);
Expand All @@ -40,13 +43,16 @@ impl<'de> Deserialize<'de> for Bitfield {
D: Deserializer<'de>,
{
let mut res = Bitfield::zero();
let bytes = strict_bytes::ByteBuf::deserialize(deserializer)?.into_vec();

let mut arr = [0u8; 4 * 8];
let len = bytes.len();
for (old, new) in bytes.iter().zip(arr[(32 - len)..].iter_mut()) {
*new = *old;
let strict_bytes::ByteBuf(bytes) = Deserialize::deserialize(deserializer)?;
if bytes.len() > MAX_LEN {
return Err(Error::invalid_length(
bytes.len(),
&"bitfield length exceeds maximum",
));
}

let mut arr = [0u8; MAX_LEN];
arr[MAX_LEN - bytes.len()..].copy_from_slice(&bytes);
res.0[3] = BigEndian::read_u64(&arr[..8]);
res.0[2] = BigEndian::read_u64(&arr[8..16]);
res.0[1] = BigEndian::read_u64(&arr[16..24]);
Expand Down
46 changes: 17 additions & 29 deletions ipld/kamt/src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::iter::FusedIterator;

use fvm_ipld_blockstore::Blockstore;
use fvm_ipld_encoding::de::DeserializeOwned;
use fvm_ipld_encoding::CborStore;

use crate::hash_bits::HashBits;
use crate::node::{match_extension, ExtensionMatch, Node};
Expand All @@ -15,6 +14,7 @@ use crate::{AsHashedKey, Config, Error, KeyValuePair};
/// Iterator over a KAMT. Items are ordered by-key, ascending.
pub struct Iter<'a, BS, V, K, H, const N: usize = 32> {
/// Blockstore handed to `Node::load` whenever a `Pointer::Link` has to be followed.
store: &'a BS,
/// KAMT configuration, forwarded to `Node::load` when a linked node is loaded.
conf: &'a Config,
/// One pointer iterator per node currently being walked, root first. A linked child's
/// pointers are pushed on top, and the stack length is used as the depth of the node
/// being loaded.
stack: Vec<std::slice::Iter<'a, Pointer<K, V, H, N>>>,
/// Key-value pairs of the bucket currently being drained; starts out empty in `new`.
current: std::slice::Iter<'a, KeyValuePair<K, V>>,
}
Expand All @@ -25,8 +25,9 @@ where
V: DeserializeOwned,
BS: Blockstore,
{
pub(crate) fn new(store: &'a BS, root: &'a Node<K, V, H, N>) -> Self {
pub(crate) fn new(store: &'a BS, root: &'a Node<K, V, H, N>, conf: &'a Config) -> Self {
Self {
conf,
store,
stack: vec![root.pointers.iter()],
current: [].iter(),
Expand All @@ -40,7 +41,7 @@ where
conf: &'a Config,
) -> Result<Self, Error>
where
K: Borrow<Q>,
K: Borrow<Q> + PartialOrd,
Q: PartialEq,
H: AsHashedKey<Q, N>,
{
Expand All @@ -57,25 +58,17 @@ where
Some(p) => match p {
Pointer::Link {
cid, cache, ext, ..
} => {
if let Some(cached_node) = cache.get() {
(cached_node, ext)
} else {
let node =
if let Some(node) = store.get_cbor::<Node<K, V, H, N>>(cid)? {
node
} else {
return Err(Error::CidNotFound(cid.to_string()));
};

// Ignore error intentionally, the cache value will always be the same
(cache.get_or_init(|| Box::new(node)), ext)
}
}
} => (
cache.get_or_try_init(|| {
Node::load(conf, store, cid, stack.len() as u32).map(Box::new)
})?,
ext,
),
Pointer::Dirty { node, ext, .. } => (node, ext),
Pointer::Values(values) => {
return match values.iter().position(|kv| kv.key().borrow() == key) {
Some(offset) => Ok(Self {
conf,
store,
stack,
current: values[offset..].iter(),
Expand Down Expand Up @@ -113,17 +106,12 @@ where
};
match next {
Pointer::Link { cid, cache, .. } => {
let node = if let Some(cached_node) = cache.get() {
cached_node
} else {
let node = match self.store.get_cbor::<Node<K, V, H, N>>(cid) {
Ok(Some(node)) => node,
Ok(None) => return Some(Err(Error::CidNotFound(cid.to_string()))),
Err(err) => return Some(Err(err.into())),
};

// Ignore error intentionally, the cache value will always be the same
cache.get_or_init(|| Box::new(node))
let node = match cache.get_or_try_init(|| {
Node::load(self.conf, self.store, cid, self.stack.len() as u32)
.map(Box::new)
}) {
Ok(node) => node,
Err(e) => return Some(Err(e)),
};
self.stack.push(node.pointers.iter())
}
Expand Down
27 changes: 9 additions & 18 deletions ipld/kamt/src/kamt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<V: PartialEq, K: PartialEq, H, BS: Blockstore, const N: usize> PartialEq

impl<BS, K, V, H, const N: usize> Kamt<BS, K, V, H, N>
where
K: Serialize + DeserializeOwned,
K: Serialize + DeserializeOwned + PartialOrd,
V: Serialize + DeserializeOwned,
BS: Blockstore,
{
Expand All @@ -90,26 +90,17 @@ where

/// Lazily instantiate a Kamt from this root Cid with the specified parameters.
pub fn load_with_config(cid: &Cid, store: BS, conf: Config) -> Result<Self, Error> {
match store.get_cbor(cid)? {
Some(root) => Ok(Self {
root,
store,
conf,
flushed_cid: Some(*cid),
}),
None => Err(Error::CidNotFound(cid.to_string())),
}
Ok(Self {
root: Node::load(&conf, &store, cid, 0)?,
store,
conf,
flushed_cid: Some(*cid),
})
}

/// Replaces the root of this Kamt with the node stored under `cid` in the Kamt's store.
///
/// The node is fetched and validated with `Node::load` at depth `0`, using this Kamt's
/// own configuration.
///
/// # Errors
///
/// Returns whatever `Node::load` returns: `Error::CidNotFound` when the store has no
/// block for `cid`, or a validation/decoding error when the block is not a well-formed
/// root node. On error the Kamt is left unchanged.
pub fn set_root(&mut self, cid: &Cid) -> Result<(), Error> {
    // Load first: `?` returns before any field is written, so a failure cannot leave
    // the root and the recorded CID out of step.
    self.root = Node::load(&self.conf, &self.store, cid, 0)?;
    // Record the CID the root was just loaded from, exactly as `load_with_config` does
    // for a freshly loaded tree. Keeping the previous value would associate the new
    // root with the CID of the tree it replaced.
    self.flushed_cid = Some(*cid);

    Ok(())
}
Expand Down Expand Up @@ -393,7 +384,7 @@ where
/// assert_eq!(x,2)
/// ```
pub fn iter(&self) -> Iter<BS, V, K, H, N> {
Iter::new(&self.store, &self.root)
Iter::new(&self.store, &self.root, &self.conf)
}

/// Iterate over the KAMT starting at the given key.
Expand Down
127 changes: 98 additions & 29 deletions ipld/kamt/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use std::fmt::Debug;

use cid::Cid;
use fvm_ipld_blockstore::Blockstore;
use fvm_ipld_encoding::CborStore;
use fvm_ipld_encoding::{CborStore, DAG_CBOR};
use multihash::Code;
use once_cell::unsync::OnceCell;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Serialize, Serializer};

use super::bitfield::Bitfield;
use super::hash_bits::HashBits;
Expand Down Expand Up @@ -46,16 +46,98 @@ where
}
}

impl<'de, K, V, H, const N: usize> Deserialize<'de> for Node<K, V, H, N>
impl<K, V, H, const N: usize> Node<K, V, H, N>
where
K: DeserializeOwned,
K: PartialOrd + DeserializeOwned,
V: DeserializeOwned,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let (bitfield, pointers) = Deserialize::deserialize(deserializer)?;
pub fn load(
conf: &Config,
store: &impl Blockstore,
k: &Cid,
depth: u32,
) -> Result<Self, Error> {
let (bitfield, pointers): (Bitfield, Vec<Pointer<K, V, H, N>>) = store
.get_cbor(k)?
.ok_or_else(|| Error::CidNotFound(k.to_string()))?;

if pointers.len() > 1 << conf.bit_width {
return Err(Error::Dynamic(anyhow::anyhow!(
"number of pointers ({}) exceeds that allowed by the bitwidth ({})",
pointers.len(),
1 << conf.bit_width,
)));
}

if bitfield.count_ones() != pointers.len() {
return Err(Error::Dynamic(anyhow::anyhow!(
"number of pointers ({}) doesn't match bitfield ({})",
pointers.len(),
bitfield.count_ones(),
)));
}

// We only allow empty pointers at the root.
if pointers.is_empty() && depth != 0 {
return Err(Error::ZeroPointers);
}

for ptr in &pointers {
match ptr {
Pointer::Values(kvs) => {
if depth < conf.min_data_depth {
return Err(Error::Dynamic(anyhow::anyhow!(
"values not allowed below the minimum data depth ({} < {})",
depth,
conf.min_data_depth,
)));
}
if kvs.is_empty() {
return Err(Error::Dynamic(anyhow::anyhow!("empty HAMT bucket")));
}
if kvs.len() > conf.max_array_width {
return Err(Error::Dynamic(anyhow::anyhow!(
"too many items in bucket {} > {}",
kvs.len(),
conf.max_array_width,
)));
}
if !kvs.windows(2).all(|window| {
let [a, b] = window else { panic!("invalid window length") };
a.key() < b.key()
}) {
return Err(Error::Dynamic(anyhow::anyhow!(
"duplicate or unsorted keys in bucket"
)));
}
}
Pointer::Link { cid, ext, .. } => {
if cid.codec() != DAG_CBOR {
return Err(Error::Dynamic(anyhow::anyhow!(
"kamt nodes must be DagCBOR, not {}",
cid.codec()
)));
}
if ext.len() % conf.bit_width != 0 {
return Err(Error::Dynamic(anyhow::anyhow!(
"extension length {} is not a multiple of the bit-width {}",
ext.len(),
conf.bit_width,
)));
}
let remaining_bits = (N as u32 * u8::BITS) - (depth + 1) * conf.bit_width;
if remaining_bits <= ext.len() {
return Err(Error::Dynamic(anyhow::anyhow!(
"extension length must be less than {} bits, was {} bits",
remaining_bits,
ext.len(),
)));
}
}
Pointer::Dirty { .. } => panic!("fresh node can't be dirty"),
}
}

Ok(Node { bitfield, pointers })
}
}
Expand Down Expand Up @@ -183,6 +265,7 @@ where
self.get_value(
&mut HashBits::new(H::as_hashed_key(key).as_ref()),
conf,
0,
key,
store,
)
Expand All @@ -192,6 +275,7 @@ where
&self,
hashed_key: &mut HashBits,
conf: &Config,
depth: u32,
key: &Q,
store: &S,
) -> Result<Option<&V>, Error>
Expand All @@ -210,19 +294,8 @@ where

let (node, ext) = match child {
Pointer::Link { cid, cache, ext } => {
let node = if let Some(cached_node) = cache.get() {
// Link node is cached
cached_node
} else {
let node: Box<Node<K, V, H, N>> = if let Some(node) = store.get_cbor(cid)? {
node
} else {
return Err(Error::CidNotFound(cid.to_string()));
};
// Intentionally ignoring error, cache will always be the same.
cache.get_or_init(|| node)
};

let node = cache
.get_or_try_init(|| Node::load(conf, store, cid, depth + 1).map(Box::new))?;
(node, ext)
}
Pointer::Dirty { node, ext } => (node, ext),
Expand All @@ -235,7 +308,7 @@ where
};

match match_extension(conf, hashed_key, ext)? {
ExtensionMatch::Full { .. } => node.get_value(hashed_key, conf, key, store),
ExtensionMatch::Full { .. } => node.get_value(hashed_key, conf, depth + 1, key, store),
ExtensionMatch::Partial { .. } => Ok(None),
}
}
Expand Down Expand Up @@ -281,9 +354,7 @@ where
Pointer::Link { cid, cache, ext } => match match_extension(conf, hashed_key, ext)? {
ExtensionMatch::Full { skipped } => {
cache.get_or_try_init(|| {
store
.get_cbor(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))
Node::load(conf, store, cid, depth + 1).map(Box::new)
})?;
let child_node = cache.get_mut().expect("filled line above");

Expand Down Expand Up @@ -446,9 +517,7 @@ where
Pointer::Link { cid, cache, ext } => match match_extension(conf, hashed_key, ext)? {
ExtensionMatch::Full { skipped } => {
cache.get_or_try_init(|| {
store
.get_cbor(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))
Node::load(conf, store, cid, depth + 1).map(Box::new)
})?;
let child_node = cache.get_mut().expect("filled line above");

Expand Down
Loading

0 comments on commit 0bbb3a0

Please sign in to comment.