Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elliptic-curve: Allow multiple dsts in the hash2curve API #1238

Merged
merged 1 commit into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions elliptic-curve/src/hash2curve/group_digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ where
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn hash_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
) -> Result<ProjectivePoint<Self>> {
let mut u = [Self::FieldElement::default(), Self::FieldElement::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
let q0 = u[0].map_to_curve();
let q1 = u[1].map_to_curve();
// Ideally we could add and then clear cofactor once
Expand Down Expand Up @@ -88,10 +88,10 @@ where
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn encode_from_bytes<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
) -> Result<ProjectivePoint<Self>> {
let mut u = [Self::FieldElement::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
let q0 = u[0].map_to_curve();
Ok(q0.clear_cofactor().into())
}
Expand All @@ -109,12 +109,15 @@ where
///
/// [`ExpandMsgXmd`]: crate::hash2curve::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2curve::ExpandMsgXof
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(msgs: &[&[u8]], dst: &'a [u8]) -> Result<Self::Scalar>
fn hash_to_scalar<'a, X: ExpandMsg<'a>>(
msgs: &[&[u8]],
dsts: &'a [&'a [u8]],
) -> Result<Self::Scalar>
where
Self::Scalar: FromOkm,
{
let mut u = [Self::Scalar::default()];
hash_to_field::<X, _>(msgs, dst, &mut u)?;
hash_to_field::<X, _>(msgs, dsts, &mut u)?;
Ok(u[0])
}
}
2 changes: 1 addition & 1 deletion elliptic-curve/src/hash2curve/hash2field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub trait FromOkm {
/// [`ExpandMsgXmd`]: crate::hash2field::ExpandMsgXmd
/// [`ExpandMsgXof`]: crate::hash2field::ExpandMsgXof
#[doc(hidden)]
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [u8], out: &mut [T]) -> Result<()>
pub fn hash_to_field<'a, E, T>(data: &[&[u8]], domain: &'a [&'a [u8]], out: &mut [T]) -> Result<()>
where
E: ExpandMsg<'a>,
T: FromOkm + Default,
Expand Down
74 changes: 52 additions & 22 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ pub trait ExpandMsg<'a> {
///
/// Returns an expander that can be used to call `read` until enough
/// bytes have been consumed
fn expand_message(msgs: &[&[u8]], dst: &'a [u8], len_in_bytes: usize)
-> Result<Self::Expander>;
fn expand_message(
msgs: &[&[u8]],
dsts: &'a [&'a [u8]],
len_in_bytes: usize,
) -> Result<Self::Expander>;
}

/// Expander that, call `read` until enough bytes have been consumed.
Expand All @@ -47,54 +50,66 @@ where
/// > 255
Hashed(GenericArray<u8, L>),
/// <= 255
Array(&'a [u8]),
Array(&'a [&'a [u8]]),
}

impl<'a, L> Domain<'a, L>
where
L: ArrayLength<u8> + IsLess<U256>,
{
pub fn xof<X>(dst: &'a [u8]) -> Result<Self>
pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
where
X: Default + ExtendableOutput + Update,
{
if dst.is_empty() {
if dsts.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
let mut data = GenericArray::<u8, L>::default();
X::default()
.chain(OVERSIZE_DST_SALT)
.chain(dst)
.finalize_xof()
.read(&mut data);
let mut hash = X::default();
hash.update(OVERSIZE_DST_SALT);

for dst in dsts {
hash.update(dst);
}

hash.finalize_xof().read(&mut data);

Ok(Self::Hashed(data))
} else {
Ok(Self::Array(dst))
Ok(Self::Array(dsts))
}
}

pub fn xmd<X>(dst: &'a [u8]) -> Result<Self>
pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
where
X: Digest<OutputSize = L>,
{
if dst.is_empty() {
if dsts.is_empty() {
Err(Error)
} else if dst.len() > MAX_DST_LEN {
} else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
Ok(Self::Hashed({
let mut hash = X::new();
hash.update(OVERSIZE_DST_SALT);
hash.update(dst);

for dst in dsts {
hash.update(dst);
}

hash.finalize()
}))
} else {
Ok(Self::Array(dst))
Ok(Self::Array(dsts))
}
}

pub fn data(&self) -> &[u8] {
pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
match self {
Self::Hashed(d) => &d[..],
Self::Array(d) => d,
Self::Hashed(d) => hash.update(d),
Self::Array(d) => {
for d in d.iter() {
hash.update(d)
}
}
}
}

Expand All @@ -103,13 +118,28 @@ where
// Can't overflow because it's enforced on a type level.
Self::Hashed(_) => L::to_u8(),
// Can't overflow because it's checked on creation.
Self::Array(d) => u8::try_from(d.len()).expect("length overflow"),
Self::Array(d) => {
u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
}
}
}

#[cfg(test)]
pub fn assert(&self, bytes: &[u8]) {
assert_eq!(self.data(), &bytes[..bytes.len() - 1]);
let data = match self {
Domain::Hashed(d) => d.to_vec(),
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
};
assert_eq!(data, bytes);
}

#[cfg(test)]
pub fn assert_dst(&self, bytes: &[u8]) {
let data = match self {
Domain::Hashed(d) => d.to_vec(),
Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
};
assert_eq!(data, &bytes[..bytes.len() - 1]);
assert_eq!(self.len(), bytes[bytes.len() - 1]);
}
}
75 changes: 38 additions & 37 deletions elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use digest::{
typenum::{IsLess, IsLessOrEqual, Unsigned, U256},
GenericArray,
},
Digest,
FixedOutput, HashMarker,
};

/// Placeholder type for implementing `expand_message_xmd` based on a hash function
Expand All @@ -22,14 +22,14 @@ use digest::{
/// - `len_in_bytes > 255 * HashT::OutputSize`
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>;

/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait
impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd<HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
// If `len_in_bytes` is bigger then 256, length of the `DST` will depend on
// the output size of the hash, which is still not allowed to be bigger then 256:
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6
Expand All @@ -42,7 +42,7 @@ where

fn expand_message(
msgs: &[&[u8]],
dst: &'a [u8],
dsts: &'a [&'a [u8]],
len_in_bytes: usize,
) -> Result<Self::Expander> {
if len_in_bytes == 0 {
Expand All @@ -54,26 +54,26 @@ where
let b_in_bytes = HashT::OutputSize::to_usize();
let ell = u8::try_from((len_in_bytes + b_in_bytes - 1) / b_in_bytes).map_err(|_| Error)?;

let domain = Domain::xmd::<HashT>(dst)?;
let mut b_0 = HashT::new();
b_0.update(GenericArray::<u8, HashT::BlockSize>::default());
let domain = Domain::xmd::<HashT>(dsts)?;
let mut b_0 = HashT::default();
b_0.update(&GenericArray::<u8, HashT::BlockSize>::default());

for msg in msgs {
b_0.update(msg);
}

b_0.update(len_in_bytes_u16.to_be_bytes());
b_0.update([0]);
b_0.update(domain.data());
b_0.update([domain.len()]);
let b_0 = b_0.finalize();
b_0.update(&len_in_bytes_u16.to_be_bytes());
b_0.update(&[0]);
domain.update_hash(&mut b_0);
b_0.update(&[domain.len()]);
let b_0 = b_0.finalize_fixed();

let mut b_vals = HashT::new();
let mut b_vals = HashT::default();
b_vals.update(&b_0[..]);
b_vals.update([1u8]);
b_vals.update(domain.data());
b_vals.update([domain.len()]);
let b_vals = b_vals.finalize();
b_vals.update(&[1u8]);
domain.update_hash(&mut b_vals);
b_vals.update(&[domain.len()]);
let b_vals = b_vals.finalize_fixed();

Ok(ExpanderXmd {
b_0,
Expand All @@ -89,7 +89,7 @@ where
/// [`Expander`] type for [`ExpandMsgXmd`].
pub struct ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand All @@ -103,7 +103,7 @@ where

impl<'a, HashT> ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand All @@ -118,12 +118,12 @@ where
.zip(&self.b_vals[..])
.enumerate()
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
let mut b_vals = HashT::new();
b_vals.update(tmp);
b_vals.update([self.index]);
b_vals.update(self.domain.data());
b_vals.update([self.domain.len()]);
self.b_vals = b_vals.finalize();
let mut b_vals = HashT::default();
b_vals.update(&tmp);
b_vals.update(&[self.index]);
self.domain.update_hash(&mut b_vals);
b_vals.update(&[self.domain.len()]);
self.b_vals = b_vals.finalize_fixed();
true
} else {
false
Expand All @@ -133,7 +133,7 @@ where

impl<'a, HashT> Expander for ExpanderXmd<'a, HashT>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize>,
{
Expand Down Expand Up @@ -165,7 +165,7 @@ mod test {
len_in_bytes: u16,
bytes: &[u8],
) where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256>,
{
let block = HashT::BlockSize::to_usize();
Expand All @@ -183,8 +183,8 @@ mod test {
let pad = l + mem::size_of::<u8>();
assert_eq!([0], &bytes[l..pad]);

let dst = pad + domain.data().len();
assert_eq!(domain.data(), &bytes[pad..dst]);
let dst = pad + usize::from(domain.len());
domain.assert(&bytes[pad..dst]);

let dst_len = dst + mem::size_of::<u8>();
assert_eq!([domain.len()], &bytes[dst..dst_len]);
Expand All @@ -205,13 +205,14 @@ mod test {
domain: &Domain<'_, HashT::OutputSize>,
) -> Result<()>
where
HashT: Digest + BlockSizeUser,
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
HashT::OutputSize: IsLess<U256> + IsLessOrEqual<HashT::BlockSize>,
{
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);

let dst = [dst];
let mut expander =
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], dst, L::to_usize())?;
ExpandMsgXmd::<HashT>::expand_message(&[self.msg], &dst, L::to_usize())?;

let mut uniform_bytes = GenericArray::<u8, L>::default();
expander.fill_bytes(&mut uniform_bytes);
Expand All @@ -227,8 +228,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826");

let dst_prime = Domain::xmd::<Sha256>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down Expand Up @@ -299,8 +300,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620");

let dst_prime = Domain::xmd::<Sha256>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha256>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down Expand Up @@ -377,8 +378,8 @@ mod test {
const DST_PRIME: &[u8] =
&hex!("515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626");

let dst_prime = Domain::xmd::<Sha512>(DST)?;
dst_prime.assert(DST_PRIME);
let dst_prime = Domain::xmd::<Sha512>(&[DST])?;
dst_prime.assert_dst(DST_PRIME);

const TEST_VECTORS_32: &[TestVector] = &[
TestVector {
Expand Down
Loading