Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix points/scalars deserialization via serde_json #143

Merged
merged 1 commit into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## v0.8.1
* Bugfix for points/scalars deserialization via serde_json [#143]

[#143]: https://github.com/ZenGo-X/curv/pull/143

## v0.8.0
* Implement Try and Increment when converting hash to scalar [#128] \
Improves performance and security of conversion 🔥
Expand Down
133 changes: 130 additions & 3 deletions src/elliptic/curves/wrappers/serde_support.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::fmt;
use std::marker::PhantomData;

use serde::de::{Error, MapAccess, Visitor};
use serde::de::{Error, IgnoredAny, MapAccess, SeqAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_bytes::Bytes;

use crate::elliptic::curves::{Curve, Point, Scalar};
use generic_array::GenericArray;
use typenum::Unsigned;

use crate::elliptic::curves::{Curve, ECPoint, ECScalar, Point, Scalar};

// ---
// --- Point (de)serialization
Expand Down Expand Up @@ -144,6 +147,48 @@ impl<'de, E: Curve> Deserialize<'de> for PointFromBytes<E> {
{
Point::from_bytes(v).map_err(|e| Err::custom(format!("invalid point: {}", e)))
}

// serde_json serializes bytes as a sequence of u8, so we need to support this format too
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let seq_len_hint = seq.size_hint();
let uncompressed_len = <E::Point as ECPoint>::UncompressedPointLength::USIZE;
let compressed_len = <E::Point as ECPoint>::CompressedPointLength::USIZE;

let mut buffer =
GenericArray::<u8, <E::Point as ECPoint>::UncompressedPointLength>::default();
let mut seq_len = 0;

for x in buffer.iter_mut() {
*x = match seq.next_element()? {
Some(b) => b,
None => break,
};
seq_len += 1;
}

if seq_len == uncompressed_len {
// Ensure that there are no other elements in the sequence
if seq.next_element::<IgnoredAny>()?.is_some() {
return Err(A::Error::invalid_length(
seq_len_hint.unwrap_or(seq_len + 1),
&format!("either {} or {} bytes", compressed_len, uncompressed_len)
.as_str(),
));
}
} else if seq_len != compressed_len {
return Err(A::Error::invalid_length(
seq_len_hint.unwrap_or(seq_len),
&format!("either {} or {} bytes", compressed_len, uncompressed_len)
.as_str(),
));
}

Point::from_bytes(&buffer.as_slice()[..seq_len])
.map_err(|e| A::Error::custom(format!("invalid point: {}", e)))
}
}

deserializer
Expand Down Expand Up @@ -242,6 +287,41 @@ impl<'de, E: Curve> Deserialize<'de> for ScalarFromBytes<E> {
{
Scalar::from_bytes(v).map_err(|_| Err::custom("invalid scalar"))
}

// serde_json serializes bytes as a sequence of u8, so we need to support this format too
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let seq_len_hint = seq.size_hint();
let expected_len = <E::Scalar as ECScalar>::ScalarLength::USIZE;

let mut buffer =
GenericArray::<u8, <E::Scalar as ECScalar>::ScalarLength>::default();

for (i, x) in buffer.iter_mut().enumerate() {
*x = match seq.next_element()? {
Some(b) => b,
None => {
return Err(A::Error::invalid_length(
i,
&format!("{} bytes", expected_len).as_str(),
))
}
};
}

// Ensure that there are no other elements in the sequence
if seq.next_element::<IgnoredAny>()?.is_some() {
return Err(A::Error::invalid_length(
seq_len_hint.unwrap_or(expected_len + 1),
&format!("{} bytes", expected_len).as_str(),
));
}

Scalar::from_bytes(buffer.as_slice())
.map_err(|_| A::Error::custom("invalid scalar"))
}
}

deserializer
Expand All @@ -259,7 +339,7 @@ enum ScalarField {

#[cfg(test)]
mod serde_tests {
use serde_test::{assert_de_tokens_error, assert_tokens, Token::*};
use serde_test::{assert_de_tokens, assert_de_tokens_error, assert_tokens, Token::*};

use crate::elliptic::curves::*;
use crate::test_for_all_curves;
Expand Down Expand Up @@ -305,6 +385,53 @@ mod serde_tests {
}
}

test_for_all_curves!(test_deserialize_point_from_seq_of_bytes);
fn test_deserialize_point_from_seq_of_bytes<E: Curve>() {
let random_point = Point::<E>::generator() * Scalar::random();
for point in [Point::zero(), random_point] {
println!("Point: {:?}", point);
let bytes = point.to_bytes(true);
let mut tokens = vec![
Struct {
name: "Point",
len: 2,
},
Str("curve"),
Str(E::CURVE_NAME),
Str("point"),
Seq {
len: Option::Some(bytes.len()),
},
];
tokens.extend(bytes.iter().copied().map(U8));
tokens.extend_from_slice(&[SeqEnd, StructEnd]);
assert_de_tokens(&point, &tokens);
}
}

test_for_all_curves!(test_deserialize_scalar_from_seq_of_bytes);
fn test_deserialize_scalar_from_seq_of_bytes<E: Curve>() {
for scalar in [Scalar::<E>::zero(), Scalar::random()] {
println!("Scalar: {:?}", scalar);
let bytes = scalar.to_bytes();
let mut tokens = vec![
Struct {
name: "Scalar",
len: 2,
},
Str("curve"),
Str(E::CURVE_NAME),
Str("scalar"),
Seq {
len: Option::Some(bytes.len()),
},
];
tokens.extend(bytes.iter().copied().map(U8));
tokens.extend_from_slice(&[SeqEnd, StructEnd]);
assert_de_tokens(&scalar, &tokens);
}
}

test_for_all_curves!(doesnt_deserialize_point_from_different_curve);
fn doesnt_deserialize_point_from_different_curve<E: Curve>() {
let tokens = vec![
Expand Down