Skip to content

Commit

Permalink
Allow multiple dsts in the hash2curve API
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpedda committed Feb 1, 2023
1 parent d69d5b9 commit 3744148
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 82 deletions.
15 changes: 9 additions & 6 deletions elliptic-curve/src/hash2curve/group_digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ where
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn hash_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
) -> Result<ProjectivePoint<Self>> {
let mut u = [Self::FieldElement::default(), Self::FieldElement::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
let q0 = u[0].map_to_curve();
let q1 = u[1].map_to_curve();
// Ideally we could add and then clear cofactor once
Expand Down Expand Up @@ -88,10 +88,10 @@ where
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn encode_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
) -> Result<ProjectivePoint<Self>> {
let mut u = [Self::FieldElement::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
let q0 = u[0].map_to_curve();
Ok(q0.clear_cofactor().into())
}
Expand All @@ -109,12 +109,15 @@ where
///
/// [`ExpandMsgXmd`]: crate::hash2curve::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result<Self::Scalar>
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dsts: &'a [&'a [u8]],
) -> Result<Self::Scalar>
where
Self::Scalar: FromOkm,
{
let mut u = [Self::Scalar::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
Ok(u[0])
}
}
2 changes: 1 addition & 1 deletion elliptic-curve/src/hash2curve/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub trait FromOkm {
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
#[doc(hidden)]
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()>
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [&'a [u8]], out: &mut [T]) -> Result<()>
where
E: ExpandMsg<'a>,
T: FromOkm + Default,
Expand Down
74 changes: 52 additions & 22 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ pub trait ExpandMsg<'a> {
///
/// Returns an expander that can be used to call `read` until enough
/// bytes have been consumed
fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize)
-> Result<Self::Expander>;
fn expand_message(
msgs: &[&[u8]],
dsts: &'a [&'a [u8]],
len_in_bytes: usize,
) -> Result<Self::Expander>;
}

/// Expander that, call `read` until enough bytes have been consumed.
Expand All @@ -47,54 +50,66 @@ where
/// > 255
Hashed(GenericArray<u8, L>),
/// <= 255
Array(&'a [u8]),
Array(&'a [&'a [u8]]),
}

impl<'a, L> Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
pub fn xof<X>(dst: &'a [u8]) -> Result<Self>
pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
where
X: Default + ExtendableOutput + Update,
{
if dst.is_empty() {
if dsts.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
let mut data = GenericArray::<u8, L>::default();
X::default()
.chain(OVERSIZE_DST_SALT)
.chain(dst)
.finalize_xof()
.read(&mut data);
let mut hash = X::default();
hash.update(OVERSIZE_DST_SALT);

for dst in dsts {
hash.update(dst);
}

hash.finalize_xof().read(&mut data);

Ok(Self::Hashed(data))
} else {
Ok(Self::Array(dst))
Ok(Self::Array(dsts))
}
}

pub fn xmd<X>(dst: &'a [u8]) -> Result<Self>
pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
where
X: Digest<OutputSize = L>,
{
if dst.is_empty() {
if dsts.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
Ok(Self::Hashed({
let mut hash = X::new();
hash.update(OVERSIZE_DST_SALT);
hash.update(dst);

for dst in dsts {
hash.update(dst);
}

hash.finalize()
}))
} else {
Ok(Self::Array(dst))
Ok(Self::Array(dsts))
}
}

pub fn data(&self) -> &[u8] {
pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
match self {
Self::Hashed(d) => &d[..],
Self::Array(d) => d,
Self::Hashed(d) => hash.update(d),
Self::Array(d) => {
for d in d.iter() {
hash.update(d)
}
}
}
}

Expand All @@ -103,13 +118,28 @@ where
// Can't overflow because it's enforced on a type level.
Self::Hashed(_) => L::to_u8(),
// Can't overflow because it's checked on creation.
Self::Array(d) => u8::try_from(d.len()).expect("length overflow"),
Self::Array(d) => {
u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
}
}
}

#[cfg(test)]
pub fn assert(&self, bytes: &[u8]) {
assert_eq!(self.data(), &bytes[..bytes.len() - 1]);
let data = match self {
Domain::Hashed(d) => d.to_vec(),
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
};
assert_eq!(data, bytes);
}

#[cfg(test)]
pub fn assert_dst(&self, bytes: &[u8]) {
let data = match self {
Domain::Hashed(d) => d.to_vec(),
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
};
assert_eq!(data, &bytes[..bytes.len() - 1]);
assert_eq!(self.len(), bytes[bytes.len() - 1]);
}
}
75 changes: 38 additions & 37 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use digest::{
typenum::{IsLess, IsLessOrEqual, Unsigned, U256},
GenericArray,
},
Digest,
FixedOutput, HashMarker,
};

/// Placeholder type for implementing `expand_message_xmd` based on a hash function
Expand All @@ -22,14 +22,14 @@ use digest::{
/// - `len_in_bytes > 255 * HashT::OutputSize`
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>;

/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait
impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd<HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
// If `len_in_bytes` is bigger then 256, length of the `DST` will depend on
// the output size of the hash, which is still not allowed to be bigger then 256:
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6
Expand All @@ -42,7 +42,7 @@ where

fn expand_message(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
len_in_bytes: usize,
) -> Result<Self::Expander> {
if len_in_bytes == 0 {
Expand All @@ -54,26 +54,26 @@ where
let b_in_bytes = HashT::OutputSize::to_usize();
let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?;

let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::new();
b_0.update(GenericArray::<u8, HashT::BlockSize>::default());
let domain = Domain::xmd::<HashT>(dsts)?;
let mut b_0 = HashT::default();
b_0.update(&GenericArray::<u8, HashT::BlockSize>::default());

for msg in msgs {
b_0.update(msg);
}

b_0.update(len_in_bytes_u16.to_be_bytes());
b_0.update([0]);
b_0.update(domain.data());
b_0.update([domain.len()]);
let b_0 = b_0.finalize();
b_0.update(&len_in_bytes_u16.to_be_bytes());
b_0.update(&[0]);
domain.update_hash(&mut b_0);
b_0.update(&[domain.len()]);
let b_0 = b_0.finalize_fixed();

let mut b_vals = HashT::new();
let mut b_vals = HashT::default();
b_vals.update(&b_0[..]);
b_vals.update([1u8]);
b_vals.update(domain.data());
b_vals.update([domain.len()]);
let b_vals = b_vals.finalize();
b_vals.update(&[1u8]);
domain.update_hash(&mut b_vals);
b_vals.update(&[domain.len()]);
let b_vals = b_vals.finalize_fixed();

Ok(ExpanderXmd {
b_0,
Expand All @@ -89,7 +89,7 @@ where
/// [`Expander`] type for [`ExpandMsgXmd`].
pub struct ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand All @@ -103,7 +103,7 @@ where

impl<'a, HashT> ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand All @@ -118,12 +118,12 @@ where
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::new();
b_vals.update(tmp);
b_vals.update([self.index]);
b_vals.update(self.domain.data());
b_vals.update([self.domain.len()]);
self.b_vals = b_vals.finalize();
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
true
} else {
false
Expand All @@ -133,7 +133,7 @@ where

impl<'a, HashT> Expander for ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand Down Expand Up @@ -165,7 +165,7 @@ mod test {
len_in_bytes: u16,
bytes: &[u8],
) where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
{
let block = HashT::BlockSize::to_usize();
Expand All @@ -183,8 +183,8 @@ mod test {
let pad = l + mem::size_of::<u8>();
assert_eq!([0], &bytes[l..pad]);

let dst = pad + domain.data().len();
assert_eq!(domain.data(), &bytes[pad..dst]);
let dst = pad + usize::from(domain.len());
domain.assert(&bytes[pad..dst]);

let dst_len = dst + mem::size_of::<u8>();
assert_eq!([domain.len()], &bytes[dst..dst_len]);
Expand All @@ -205,13 +205,14 @@ mod test {
domain: &Domain<'_, HashT::OutputSize>,
) -> Result<()>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256> + IsLessOrEqual<HashT::BlockSize>,
{
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);

let dst = [dst];
let mut expander =
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], dst, L::to_usize())?;
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], &dst, L::to_usize())?;

let mut uniform_bytes = GenericArray::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);
Expand All @@ -227,8 +228,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826");

let dst_prime = Domain::xmd::<Sha256>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down Expand Up @@ -299,8 +300,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620");

let dst_prime = Domain::xmd::<Sha256>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down Expand Up @@ -377,8 +378,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626");

let dst_prime = Domain::xmd::<Sha512>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha512>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down
Loading

0 comments on commit 3744148

Please sign in to comment.