Skip to content

Commit

Permalink
Move to ZeroizeOnDrop (#54)
Browse files Browse the repository at this point in the history
* Move from `Zeroize` to `ZeroizeOnDrop`

* Pin pre-release dependencies
  • Loading branch information
daxpedda authored Jan 28, 2022
1 parent b59b359 commit 9eee936
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 42 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ std = ["alloc"]

[dependencies]
curve25519-dalek = { version = "3", default-features = false, optional = true }
derive-where = { version = "1.0.0-rc.1", features = ["zeroize"] }
derive-where = { version = "=1.0.0-rc.2", features = ["zeroize-on-drop"] }
digest = "0.10"
displaydoc = { version = "0.2", default-features = false }
elliptic-curve = { version = "0.12.0-pre.1", features = [
elliptic-curve = { version = "=0.12.0-pre.1", features = [
"hash2curve",
"sec1",
"voprf",
Expand All @@ -42,13 +42,13 @@ serde_ = { version = "1", package = "serde", default-features = false, features
], optional = true }
sha2 = { version = "0.10", default-features = false, optional = true }
subtle = { version = "2.3", default-features = false }
zeroize = { version = "1", default-features = false }
zeroize = { version = "1.5", default-features = false }

[dev-dependencies]
generic-array = { version = "0.14", features = ["more_lengths"] }
hex = "0.4"
json = "0.12"
p256 = { version = "0.11.0-pre.0", default-features = false, features = [
p256 = { version = "=0.11.0-pre.0", default-features = false, features = [
"hash2curve",
"voprf",
] }
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@
//!
//! [curve25519-dalek]: (https://doc.dalek.rs/curve25519_dalek/index.html#backends-and-features)
#![deny(unsafe_code)]
#![cfg_attr(not(test), deny(unsafe_code))]
#![no_std]
#![warn(
clippy::cargo,
Expand Down
58 changes: 21 additions & 37 deletions src/voprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
use alloc::vec::Vec;
use core::iter::{self, Map, Repeat, Zip};

use derive_where::DeriveWhere;
use derive_where::derive_where;
use digest::core_api::BlockSizeUser;
use digest::{Digest, Output, OutputSizeUser};
use generic_array::sequence::Concat;
Expand Down Expand Up @@ -64,8 +64,7 @@ impl Mode {

/// A client which engages with a [NonVerifiableServer] in base mode, meaning
/// that the OPRF outputs are not verifiable.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
Expand All @@ -83,8 +82,7 @@ where

/// A client which engages with a [VerifiableServer] in verifiable mode, meaning
/// that the OPRF outputs can be checked against a server public key.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
Expand All @@ -104,8 +102,7 @@ where

/// A server which engages with a [NonVerifiableClient] in base mode, meaning
/// that the OPRF outputs are not verifiable.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
Expand All @@ -123,8 +120,7 @@ where

/// A server which engages with a [VerifiableClient] in verifiable mode, meaning
/// that the OPRF outputs can be checked against a server public key.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
Expand All @@ -144,8 +140,7 @@ where

/// A proof produced by a [VerifiableServer] that the OPRF output matches
/// against a server public key.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
Expand All @@ -165,8 +160,7 @@ where

/// The first client message sent from a client (either verifiable or not) to a
/// server (either verifiable or not).
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
Expand All @@ -183,8 +177,7 @@ where

/// The server's response to the [BlindedElement] message from a client (either
/// verifiable or not) to a server (either verifiable or not).
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
Expand Down Expand Up @@ -768,7 +761,6 @@ where
/////////////////////////

/// Contains the fields that are returned by a non-verifiable client blind
#[derive(DeriveWhere)]
#[derive_where(Debug; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
pub struct NonVerifiableClientBlindResult<CS: CipherSuite>
where
Expand All @@ -782,7 +774,6 @@ where
}

/// Contains the fields that are returned by a verifiable client blind
#[derive(DeriveWhere)]
#[derive_where(Debug; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
pub struct VerifiableClientBlindResult<CS: CipherSuite>
where
Expand All @@ -804,7 +795,6 @@ pub type VerifiableClientBatchFinalizeResult<'a, C, I, II, IC, IM> = FinalizeAft
>;

/// Contains the fields that are returned by a verifiable server evaluate
#[derive(DeriveWhere)]
#[derive_where(Debug; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
pub struct VerifiableServerEvaluateResult<CS: CipherSuite>
where
Expand All @@ -819,8 +809,7 @@ where

/// Contains prepared [`EvaluationElement`]s by a verifiable server batch
/// evaluate preparation.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
Expand All @@ -833,8 +822,7 @@ where
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;

/// Contains the prepared `t` by a verifiable server batch evaluate preparation.
#[derive(DeriveWhere)]
#[derive_where(Clone, Zeroize(drop))]
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
Expand Down Expand Up @@ -863,7 +851,6 @@ pub type VerifiableServerBatchEvaluatePreparedEvaluationElements<CS, I> = Map<

/// Contains the fields that are returned by a verifiable server batch evaluate
/// preparation.
#[derive(DeriveWhere)]
#[derive_where(Debug; I, <CS::Group as Group>::Scalar)]
pub struct VerifiableServerBatchEvaluatePrepareResult<
'a,
Expand All @@ -889,7 +876,6 @@ pub type VerifiableServerBatchEvaluateFinishedMessages<'a, CS, I> = Map<

/// Contains the fields that are returned by a verifiable server batch evaluate
/// finish.
#[derive(DeriveWhere)]
#[derive_where(Debug; <&'a I as core::iter::IntoIterator>::IntoIter, <CS::Group as Group>::Scalar)]
pub struct VerifiableServerBatchEvaluateFinishResult<'a, CS: 'a + CipherSuite, I>
where
Expand All @@ -904,7 +890,6 @@ where
}

/// Contains the fields that are returned by a verifiable server batch evaluate
#[derive(DeriveWhere)]
#[derive_where(Debug; <CS::Group as Group>::Scalar, <CS::Group as Group>::Elem)]
#[cfg(feature = "alloc")]
pub struct VerifiableServerBatchEvaluateResult<CS: CipherSuite>
Expand Down Expand Up @@ -1319,13 +1304,13 @@ where
#[cfg(test)]
mod tests {
use core::ops::Add;
use core::ptr;

use ::alloc::vec;
use ::alloc::vec::Vec;
use generic_array::typenum::Sum;
use generic_array::ArrayLength;
use rand::rngs::OsRng;
use zeroize::Zeroize;

use super::*;
use crate::Group;
Expand Down Expand Up @@ -1579,11 +1564,11 @@ mod tests {
let client_blind_result = NonVerifiableClient::<CS>::blind(input, &mut rng).unwrap();

let mut state = client_blind_result.state;
Zeroize::zeroize(&mut state);
unsafe { ptr::drop_in_place(&mut state) };
assert!(state.serialize().iter().all(|&x| x == 0));

let mut message = client_blind_result.message;
Zeroize::zeroize(&mut message);
unsafe { ptr::drop_in_place(&mut message) };
assert!(message.serialize().iter().all(|&x| x == 0));
}

Expand All @@ -1599,11 +1584,11 @@ mod tests {
let client_blind_result = VerifiableClient::<CS>::blind(input, &mut rng).unwrap();

let mut state = client_blind_result.state;
Zeroize::zeroize(&mut state);
unsafe { ptr::drop_in_place(&mut state) };
assert!(state.serialize().iter().all(|&x| x == 0));

let mut message = client_blind_result.message;
Zeroize::zeroize(&mut message);
unsafe { ptr::drop_in_place(&mut message) };
assert!(message.serialize().iter().all(|&x| x == 0));
}

Expand All @@ -1617,16 +1602,15 @@ mod tests {
let mut rng = OsRng;
let client_blind_result = NonVerifiableClient::<CS>::blind(input, &mut rng).unwrap();
let server = NonVerifiableServer::<CS>::new(&mut rng);
let message = server
let mut message = server
.evaluate(&client_blind_result.message, Some(info))
.unwrap();

let mut state = server;
Zeroize::zeroize(&mut state);
unsafe { ptr::drop_in_place(&mut state) };
assert!(state.serialize().iter().all(|&x| x == 0));

let mut message = message;
Zeroize::zeroize(&mut message);
unsafe { ptr::drop_in_place(&mut message) };
assert!(message.serialize().iter().all(|&x| x == 0));
}

Expand All @@ -1649,15 +1633,15 @@ mod tests {
.unwrap();

let mut state = server;
Zeroize::zeroize(&mut state);
unsafe { ptr::drop_in_place(&mut state) };
assert!(state.serialize().iter().all(|&x| x == 0));

let mut message = server_result.message;
Zeroize::zeroize(&mut message);
unsafe { ptr::drop_in_place(&mut message) };
assert!(message.serialize().iter().all(|&x| x == 0));

let mut proof = server_result.proof;
Zeroize::zeroize(&mut proof);
unsafe { ptr::drop_in_place(&mut proof) };
assert!(proof.serialize().iter().all(|&x| x == 0));
}

Expand Down

0 comments on commit 9eee936

Please sign in to comment.