From 7d270556245d5b1a0a94bb76fe26660cbafbd74b Mon Sep 17 00:00:00 2001 From: Joseph Spadavecchia Date: Fri, 18 Jun 2021 13:43:47 +0100 Subject: [PATCH] Simplified sponge interface --- dlog/plonk-5-wires/src/plonk_sponge.rs | 9 ++-- dlog/plonk/src/plonk_sponge.rs | 9 ++-- dlog/tests/turbo_plonk.rs | 10 ++--- dlog/tests/turbo_plonk_5_wires.rs | 5 +-- oracle/src/poseidon.rs | 59 +++++++++++++------------- oracle/src/sponge.rs | 20 ++++----- oracle/tests/poseidon_tests.rs | 18 ++++---- 7 files changed, 60 insertions(+), 70 deletions(-) diff --git a/dlog/plonk-5-wires/src/plonk_sponge.rs b/dlog/plonk-5-wires/src/plonk_sponge.rs index 1500f05512..8e03a6610b 100644 --- a/dlog/plonk-5-wires/src/plonk_sponge.rs +++ b/dlog/plonk-5-wires/src/plonk_sponge.rs @@ -15,15 +15,14 @@ pub trait FrSponge { impl FrSponge for DefaultFrSponge { fn new(params: ArithmeticSpongeParams) -> DefaultFrSponge { DefaultFrSponge { - params, - sponge: ArithmeticSponge::new(), + sponge: ArithmeticSponge::new(params), last_squeezed: vec![], } } fn absorb(&mut self, x: &Fr) { self.last_squeezed = vec![]; - self.sponge.absorb(&self.params, &[*x]); + self.sponge.absorb(&[*x]); } fn challenge(&mut self) -> ScalarChallenge { @@ -32,7 +31,7 @@ impl FrSponge for DefaultFrSponge { fn absorb_evaluations(&mut self, p: &[Fr], e: &ProofEvaluations>) { self.last_squeezed = vec![]; - self.sponge.absorb(&self.params, p); + self.sponge.absorb(p); let points = [ &e.w[0], &e.w[1], &e.w[2], &e.w[3], &e.w[4], &e.z, &e.t, &e.f, &e.s[0], &e.s[1], @@ -40,7 +39,7 @@ impl FrSponge for DefaultFrSponge { ]; for p in &points { - self.sponge.absorb(&self.params, p); + self.sponge.absorb(p); } } } diff --git a/dlog/plonk/src/plonk_sponge.rs b/dlog/plonk/src/plonk_sponge.rs index 7b438ba3dd..4fc60c7053 100644 --- a/dlog/plonk/src/plonk_sponge.rs +++ b/dlog/plonk/src/plonk_sponge.rs @@ -15,15 +15,14 @@ pub trait FrSponge { impl FrSponge for DefaultFrSponge { fn new(params: ArithmeticSpongeParams) -> DefaultFrSponge { DefaultFrSponge { - params, - sponge: ArithmeticSponge::new(), + sponge: ArithmeticSponge::new(params), last_squeezed: vec![], } } fn absorb(&mut self, x: &Fr) { self.last_squeezed = vec![]; - self.sponge.absorb(&self.params, &[*x]); + self.sponge.absorb(&[*x]); } fn challenge(&mut self) -> ScalarChallenge { @@ -32,12 +31,12 @@ impl FrSponge for DefaultFrSponge { fn absorb_evaluations(&mut self, p: &[Fr], e: &ProofEvaluations>) { self.last_squeezed = vec![]; - self.sponge.absorb(&self.params, p); + self.sponge.absorb(p); let points = [&e.l, &e.r, &e.o, &e.z, &e.f, &e.sigma1, &e.sigma2, &e.t]; for p in &points { - self.sponge.absorb(&self.params, p); + self.sponge.absorb(p); } } } diff --git a/dlog/tests/turbo_plonk.rs b/dlog/tests/turbo_plonk.rs index 437245eca6..aa86c91e96 100644 --- a/dlog/tests/turbo_plonk.rs +++ b/dlog/tests/turbo_plonk.rs @@ -528,8 +528,7 @@ where { let rng = &mut OsRng; - let params: ArithmeticSpongeParams = oracle::pasta::fp::params(); - let mut sponge = ArithmeticSponge::::new(); + let mut sponge = ArithmeticSponge::::new(oracle::pasta::fp::params()); let z = Fp::zero(); let mut batch = Vec::new(); @@ -629,7 +628,7 @@ where // HALF_ROUNDS_FULL full rounds constraint gates for j in 0..SC::ROUNDS_FULL { - sponge.full_round(j + 1, ¶ms); + sponge.full_round(j + 1); l.push(sponge.state[0]); r.push(sponge.state[1]); o.push(sponge.state[2]); @@ -823,8 +822,7 @@ where let s = (y2 - &y1) / &(x2 - &x1); - let mut sponge = ArithmeticSponge::::new(); - let params: ArithmeticSpongeParams = oracle::pasta::fp::params(); + let mut sponge = ArithmeticSponge::::new(oracle::pasta::fp::params()); sponge.state = vec![x1, x2, x3]; let z = Fp::zero(); @@ -882,7 +880,7 @@ where // ROUNDS_FULL full rounds constraint gates for j in 0..SC::ROUNDS_FULL { - sponge.full_round(j + 1, ¶ms); + sponge.full_round(j + 1); l.push(sponge.state[0]); r.push(sponge.state[1]); o.push(sponge.state[2]); diff --git a/dlog/tests/turbo_plonk_5_wires.rs b/dlog/tests/turbo_plonk_5_wires.rs index 4943923935..85209ea134 100644 --- a/dlog/tests/turbo_plonk_5_wires.rs +++ b/dlog/tests/turbo_plonk_5_wires.rs @@ -839,7 +839,6 @@ fn positive(index: &Index) { let rng = &mut OsRng; let mut batch = Vec::new(); let group_map = ::Map::setup(); - let params = oracle::pasta::fp5::params(); let lgr_comms: Vec<_> = (0..PUBLIC) .map(|i| { let mut v = vec![Fp::zero(); i + 1]; @@ -1016,7 +1015,7 @@ fn positive(index: &Index) { // witness for Poseidon permutation custom constraints - let mut sponge = ArithmeticSponge::::new(); + let mut sponge = ArithmeticSponge::::new(oracle::pasta::fp5::params()); sponge.state = vec![w(), w(), w(), w(), w()]; witness .iter_mut() @@ -1026,7 +1025,7 @@ fn positive(index: &Index) { // ROUNDS_FULL full rounds for j in 0..PlonkSpongeConstants5W::ROUNDS_FULL { - sponge.full_round(j, ¶ms); + sponge.full_round(j); witness .iter_mut() .zip(sponge.state.iter()) diff --git a/oracle/src/poseidon.rs b/oracle/src/poseidon.rs index 8d4d53c46c..070f14d29e 100644 --- a/oracle/src/poseidon.rs +++ b/oracle/src/poseidon.rs @@ -78,11 +78,10 @@ impl SpongeConstants for PlonkSpongeConstants3 { const INITIAL_ARK: bool = false; } -pub trait Sponge { - type Params; - fn new() -> Self; - fn absorb(&mut self, params: &Self::Params, x: &[Input]); - fn squeeze(&mut self, params: &Self::Params) -> Digest; +pub trait Sponge { + fn new(params: ArithmeticSpongeParams) -> Self; + fn absorb(&mut self, x: &[Input]); + fn squeeze(&mut self) -> Digest; } pub fn sbox(x: F) -> F { @@ -106,13 +105,14 @@ pub struct ArithmeticSponge { pub sponge_state: SpongeState, rate: usize, pub state: Vec, + params: ArithmeticSpongeParams, pub constants: std::marker::PhantomData, } impl ArithmeticSponge { - fn apply_mds_matrix(&mut self, params: &ArithmeticSpongeParams) { + fn apply_mds_matrix(&mut self) { self.state = if SC::FULL_MDS { - params + self.params .mds .iter() .map(|m| { @@ -131,40 +131,40 @@ impl ArithmeticSponge { }; } - pub fn full_round(&mut self, r: usize, params: &ArithmeticSpongeParams) { + pub fn full_round(&mut self, r: usize) { for i in 0..self.state.len() { self.state[i] = sbox::(self.state[i]); } - self.apply_mds_matrix(params); - for (i, x) in params.round_constants[r].iter().enumerate() { + self.apply_mds_matrix(); + for (i, x) in self.params.round_constants[r].iter().enumerate() { self.state[i].add_assign(x); } } - fn half_rounds(&mut self, params: &ArithmeticSpongeParams) { + fn half_rounds(&mut self) { for r in 0..SC::HALF_ROUNDS_FULL { - for (i, x) in params.round_constants[r].iter().enumerate() { + for (i, x) in self.params.round_constants[r].iter().enumerate() { self.state[i].add_assign(x); } for i in 0..self.state.len() { self.state[i] = sbox::(self.state[i]); } - self.apply_mds_matrix(params); + self.apply_mds_matrix(); } for r in 0..SC::ROUNDS_PARTIAL { - for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + r] + for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + r] .iter() .enumerate() { self.state[i].add_assign(x); } self.state[0] = sbox::(self.state[0]); - self.apply_mds_matrix(params); + self.apply_mds_matrix(); } for r in 0..SC::HALF_ROUNDS_FULL { - for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r] + for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r] .iter() .enumerate() { @@ -173,34 +173,32 @@ impl ArithmeticSponge { for i in 0..self.state.len() { self.state[i] = sbox::(self.state[i]); } - self.apply_mds_matrix(params); + self.apply_mds_matrix(); } } - fn poseidon_block_cipher(&mut self, params: &ArithmeticSpongeParams) { + fn poseidon_block_cipher(&mut self) { if SC::HALF_ROUNDS_FULL == 0 { if SC::INITIAL_ARK == true { - for (i, x) in params.round_constants[0].iter().enumerate() { + for (i, x) in self.params.round_constants[0].iter().enumerate() { self.state[i].add_assign(x); } for r in 0..SC::ROUNDS_FULL { - self.full_round(r + 1, params); + self.full_round(r + 1); } } else { for r in 0..SC::ROUNDS_FULL { - self.full_round(r, params); + self.full_round(r); } } } else { - self.half_rounds(params); + self.half_rounds(); } } } impl Sponge for ArithmeticSponge { - type Params = ArithmeticSpongeParams; - - fn new() -> ArithmeticSponge { + fn new(params: ArithmeticSpongeParams) -> ArithmeticSponge { let capacity = SC::SPONGE_CAPACITY; let rate = SC::SPONGE_RATE; @@ -214,16 +212,17 @@ impl Sponge for ArithmeticSponge { state, rate, sponge_state: SpongeState::Absorbed(0), + params, constants: std::marker::PhantomData, } } - fn absorb(&mut self, params: &ArithmeticSpongeParams, x: &[F]) { + fn absorb(&mut self, x: &[F]) { for x in x.iter() { match self.sponge_state { SpongeState::Absorbed(n) => { if n == self.rate { - self.poseidon_block_cipher(params); + self.poseidon_block_cipher(); self.sponge_state = SpongeState::Absorbed(1); self.state[0].add_assign(x); } else { @@ -239,11 +238,11 @@ impl Sponge for ArithmeticSponge { } } - fn squeeze(&mut self, params: &ArithmeticSpongeParams) -> F { + fn squeeze(&mut self) -> F { match self.sponge_state { SpongeState::Squeezed(n) => { if n == self.rate { - self.poseidon_block_cipher(params); + self.poseidon_block_cipher(); self.sponge_state = SpongeState::Squeezed(1); self.state[0] } else { @@ -252,7 +251,7 @@ impl Sponge for ArithmeticSponge { } } SpongeState::Absorbed(_n) => { - self.poseidon_block_cipher(params); + self.poseidon_block_cipher(); self.sponge_state = SpongeState::Squeezed(1); self.state[0] } diff --git a/oracle/src/sponge.rs b/oracle/src/sponge.rs index b4cbe9f081..6d4776bce6 100644 --- a/oracle/src/sponge.rs +++ b/oracle/src/sponge.rs @@ -62,13 +62,11 @@ impl ScalarChallenge { #[derive(Clone)] pub struct DefaultFqSponge { - pub params: ArithmeticSpongeParams, pub sponge: ArithmeticSponge, pub last_squeezed: Vec, } pub struct DefaultFrSponge { - pub params: ArithmeticSpongeParams, pub sponge: ArithmeticSponge, pub last_squeezed: Vec, } @@ -90,7 +88,7 @@ impl DefaultFrSponge { self.last_squeezed = remaining.to_vec(); Fr::from_repr(pack::(&limbs)) } else { - let x = self.sponge.squeeze(&self.params).into_repr(); + let x = self.sponge.squeeze().into_repr(); self.last_squeezed .extend(&x.as_ref()[0..HIGH_ENTROPY_LIMBS]); self.squeeze(num_limbs) @@ -110,7 +108,7 @@ where self.last_squeezed = remaining.to_vec(); limbs.to_vec() } else { - let x = self.sponge.squeeze(&self.params).into_repr(); + let x = self.sponge.squeeze().into_repr(); self.last_squeezed .extend(&x.as_ref()[0..HIGH_ENTROPY_LIMBS]); self.squeeze_limbs(num_limbs) @@ -119,7 +117,7 @@ where pub fn squeeze_field(&mut self) -> P::BaseField { self.last_squeezed = vec![]; - self.sponge.squeeze(&self.params) + self.sponge.squeeze() } pub fn squeeze(&mut self, num_limbs: usize) -> P::ScalarField { @@ -135,8 +133,7 @@ where { fn new(params: ArithmeticSpongeParams) -> DefaultFqSponge { DefaultFqSponge { - params, - sponge: ArithmeticSponge::new(), + sponge: ArithmeticSponge::new(params), last_squeezed: vec![], } } @@ -147,8 +144,8 @@ where if g.infinity { panic!("sponge got zero curve point"); } else { - self.sponge.absorb(&self.params, &[g.x]); - self.sponge.absorb(&self.params, &[g.y]); + self.sponge.absorb(&[g.x]); + self.sponge.absorb(&[g.y]); } } } @@ -172,7 +169,6 @@ where < ::Params::MODULUS.into() { self.sponge.absorb( - &self.params, &[P::BaseField::from_repr( ::BigInt::from_bits(&bits), )], @@ -188,8 +184,8 @@ where P::BaseField::zero() }; - self.sponge.absorb(&self.params, &[low_bits]); - self.sponge.absorb(&self.params, &[high_bit]); + self.sponge.absorb(&[low_bits]); + self.sponge.absorb(&[high_bit]); } }); } diff --git a/oracle/tests/poseidon_tests.rs b/oracle/tests/poseidon_tests.rs index 42765f6847..2ff1dc8f5b 100644 --- a/oracle/tests/poseidon_tests.rs +++ b/oracle/tests/poseidon_tests.rs @@ -27,9 +27,9 @@ mod tests { fn poseidon() { macro_rules! assert_poseidon_eq { ($input:expr, $target:expr) => { - let mut s = Poseidon::::new(); - s.absorb(&Parameters::params(), $input); - let output = s.squeeze(&Parameters::params()); + let mut s = Poseidon::::new(Parameters::params()); + s.absorb($input); + let output = s.squeeze(); assert_eq!( output, $target, @@ -174,9 +174,9 @@ mod tests { fn poseidon_5_wires() { macro_rules! assert_poseidon_5_wires_eq { ($input:expr, $target:expr) => { - let mut s = Poseidon::::new(); - s.absorb(&Parameters5W::params(), $input); - let output = s.squeeze(&Parameters5W::params()); + let mut s = Poseidon::::new(Parameters5W::params()); + s.absorb($input); + let output = s.squeeze(); assert_eq!( output, $target, @@ -369,9 +369,9 @@ mod tests { fn poseidon_3() { macro_rules! assert_poseidon_3_eq { ($input:expr, $target:expr) => { - let mut s = Poseidon::::new(); - s.absorb(&Parameters3::params(), $input); - let output = s.squeeze(&Parameters3::params()); + let mut s = Poseidon::::new(Parameters3::params()); + s.absorb($input); + let output = s.squeeze(); assert_eq!( output, $target,