Skip to content

Commit

Permalink
Merge pull request #117 from o1-labs/jspada/sponge-improvements
Browse files Browse the repository at this point in the history
Simplified sponge interface
  • Loading branch information
jspada authored Jun 18, 2021
2 parents a720d53 + 7d27055 commit d89ad19
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 70 deletions.
9 changes: 4 additions & 5 deletions dlog/plonk-5-wires/src/plonk_sponge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ pub trait FrSponge<Fr: Field> {
impl<Fr: PrimeField> FrSponge<Fr> for DefaultFrSponge<Fr, SC> {
fn new(params: ArithmeticSpongeParams<Fr>) -> DefaultFrSponge<Fr, SC> {
DefaultFrSponge {
params,
sponge: ArithmeticSponge::new(),
sponge: ArithmeticSponge::new(params),
last_squeezed: vec![],
}
}

fn absorb(&mut self, x: &Fr) {
self.last_squeezed = vec![];
self.sponge.absorb(&self.params, &[*x]);
self.sponge.absorb(&[*x]);
}

fn challenge(&mut self) -> ScalarChallenge<Fr> {
Expand All @@ -32,15 +31,15 @@ impl<Fr: PrimeField> FrSponge<Fr> for DefaultFrSponge<Fr, SC> {

fn absorb_evaluations(&mut self, p: &[Fr], e: &ProofEvaluations<Vec<Fr>>) {
self.last_squeezed = vec![];
self.sponge.absorb(&self.params, p);
self.sponge.absorb(p);

let points = [
&e.w[0], &e.w[1], &e.w[2], &e.w[3], &e.w[4], &e.z, &e.t, &e.f, &e.s[0], &e.s[1],
&e.s[2], &e.s[3],
];

for p in &points {
self.sponge.absorb(&self.params, p);
self.sponge.absorb(p);
}
}
}
9 changes: 4 additions & 5 deletions dlog/plonk/src/plonk_sponge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ pub trait FrSponge<Fr: Field> {
impl<Fr: PrimeField> FrSponge<Fr> for DefaultFrSponge<Fr, SC> {
fn new(params: ArithmeticSpongeParams<Fr>) -> DefaultFrSponge<Fr, SC> {
DefaultFrSponge {
params,
sponge: ArithmeticSponge::new(),
sponge: ArithmeticSponge::new(params),
last_squeezed: vec![],
}
}

fn absorb(&mut self, x: &Fr) {
self.last_squeezed = vec![];
self.sponge.absorb(&self.params, &[*x]);
self.sponge.absorb(&[*x]);
}

fn challenge(&mut self) -> ScalarChallenge<Fr> {
Expand All @@ -32,12 +31,12 @@ impl<Fr: PrimeField> FrSponge<Fr> for DefaultFrSponge<Fr, SC> {

fn absorb_evaluations(&mut self, p: &[Fr], e: &ProofEvaluations<Vec<Fr>>) {
self.last_squeezed = vec![];
self.sponge.absorb(&self.params, p);
self.sponge.absorb(p);

let points = [&e.l, &e.r, &e.o, &e.z, &e.f, &e.sigma1, &e.sigma2, &e.t];

for p in &points {
self.sponge.absorb(&self.params, p);
self.sponge.absorb(p);
}
}
}
10 changes: 4 additions & 6 deletions dlog/tests/turbo_plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,7 @@ where
{
let rng = &mut OsRng;

let params: ArithmeticSpongeParams<Fp> = oracle::pasta::fp::params();
let mut sponge = ArithmeticSponge::<Fp, SC>::new();
let mut sponge = ArithmeticSponge::<Fp, SC>::new(oracle::pasta::fp::params());

let z = Fp::zero();
let mut batch = Vec::new();
Expand Down Expand Up @@ -629,7 +628,7 @@ where

// HALF_ROUNDS_FULL full rounds constraint gates
for j in 0..SC::ROUNDS_FULL {
sponge.full_round(j + 1, &params);
sponge.full_round(j + 1);
l.push(sponge.state[0]);
r.push(sponge.state[1]);
o.push(sponge.state[2]);
Expand Down Expand Up @@ -823,8 +822,7 @@ where

let s = (y2 - &y1) / &(x2 - &x1);

let mut sponge = ArithmeticSponge::<Fp, SC>::new();
let params: ArithmeticSpongeParams<Fp> = oracle::pasta::fp::params();
let mut sponge = ArithmeticSponge::<Fp, SC>::new(oracle::pasta::fp::params());
sponge.state = vec![x1, x2, x3];
let z = Fp::zero();

Expand Down Expand Up @@ -882,7 +880,7 @@ where

// ROUNDS_FULL full rounds constraint gates
for j in 0..SC::ROUNDS_FULL {
sponge.full_round(j + 1, &params);
sponge.full_round(j + 1);
l.push(sponge.state[0]);
r.push(sponge.state[1]);
o.push(sponge.state[2]);
Expand Down
5 changes: 2 additions & 3 deletions dlog/tests/turbo_plonk_5_wires.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,6 @@ fn positive(index: &Index<Affine>) {
let rng = &mut OsRng;
let mut batch = Vec::new();
let group_map = <Affine as CommitmentCurve>::Map::setup();
let params = oracle::pasta::fp5::params();
let lgr_comms: Vec<_> = (0..PUBLIC)
.map(|i| {
let mut v = vec![Fp::zero(); i + 1];
Expand Down Expand Up @@ -1016,7 +1015,7 @@ fn positive(index: &Index<Affine>) {

// witness for Poseidon permutation custom constraints

let mut sponge = ArithmeticSponge::<Fp, PlonkSpongeConstants5W>::new();
let mut sponge = ArithmeticSponge::<Fp, PlonkSpongeConstants5W>::new(oracle::pasta::fp5::params());
sponge.state = vec![w(), w(), w(), w(), w()];
witness
.iter_mut()
Expand All @@ -1026,7 +1025,7 @@ fn positive(index: &Index<Affine>) {
// ROUNDS_FULL full rounds

for j in 0..PlonkSpongeConstants5W::ROUNDS_FULL {
sponge.full_round(j, &params);
sponge.full_round(j);
witness
.iter_mut()
.zip(sponge.state.iter())
Expand Down
59 changes: 29 additions & 30 deletions oracle/src/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ impl SpongeConstants for PlonkSpongeConstants3 {
const INITIAL_ARK: bool = false;
}

pub trait Sponge<Input, Digest> {
type Params;
fn new() -> Self;
fn absorb(&mut self, params: &Self::Params, x: &[Input]);
fn squeeze(&mut self, params: &Self::Params) -> Digest;
pub trait Sponge<Input: Field, Digest> {
fn new(params: ArithmeticSpongeParams<Input>) -> Self;
fn absorb(&mut self, x: &[Input]);
fn squeeze(&mut self) -> Digest;
}

pub fn sbox<F: Field, SC: SpongeConstants>(x: F) -> F {
Expand All @@ -106,13 +105,14 @@ pub struct ArithmeticSponge<F: Field, SC: SpongeConstants> {
pub sponge_state: SpongeState,
rate: usize,
pub state: Vec<F>,
params: ArithmeticSpongeParams<F>,
pub constants: std::marker::PhantomData<SC>,
}

impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
fn apply_mds_matrix(&mut self, params: &ArithmeticSpongeParams<F>) {
fn apply_mds_matrix(&mut self) {
self.state = if SC::FULL_MDS {
params
self.params
.mds
.iter()
.map(|m| {
Expand All @@ -131,40 +131,40 @@ impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
};
}

pub fn full_round(&mut self, r: usize, params: &ArithmeticSpongeParams<F>) {
pub fn full_round(&mut self, r: usize) {
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
}
self.apply_mds_matrix(params);
for (i, x) in params.round_constants[r].iter().enumerate() {
self.apply_mds_matrix();
for (i, x) in self.params.round_constants[r].iter().enumerate() {
self.state[i].add_assign(x);
}
}

fn half_rounds(&mut self, params: &ArithmeticSpongeParams<F>) {
fn half_rounds(&mut self) {
for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
for (i, x) in self.params.round_constants[r].iter().enumerate() {
self.state[i].add_assign(x);
}
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
}
self.apply_mds_matrix(params);
self.apply_mds_matrix();
}

for r in 0..SC::ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + r]
for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
self.state[i].add_assign(x);
}
self.state[0] = sbox::<F, SC>(self.state[0]);
self.apply_mds_matrix(params);
self.apply_mds_matrix();
}

for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r]
for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
Expand All @@ -173,34 +173,32 @@ impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
}
self.apply_mds_matrix(params);
self.apply_mds_matrix();
}
}

fn poseidon_block_cipher(&mut self, params: &ArithmeticSpongeParams<F>) {
fn poseidon_block_cipher(&mut self) {
if SC::HALF_ROUNDS_FULL == 0 {
if SC::INITIAL_ARK == true {
for (i, x) in params.round_constants[0].iter().enumerate() {
for (i, x) in self.params.round_constants[0].iter().enumerate() {
self.state[i].add_assign(x);
}
for r in 0..SC::ROUNDS_FULL {
self.full_round(r + 1, params);
self.full_round(r + 1);
}
} else {
for r in 0..SC::ROUNDS_FULL {
self.full_round(r, params);
self.full_round(r);
}
}
} else {
self.half_rounds(params);
self.half_rounds();
}
}
}

impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
type Params = ArithmeticSpongeParams<F>;

fn new() -> ArithmeticSponge<F, SC> {
fn new(params: ArithmeticSpongeParams<F>) -> ArithmeticSponge<F, SC> {
let capacity = SC::SPONGE_CAPACITY;
let rate = SC::SPONGE_RATE;

Expand All @@ -214,16 +212,17 @@ impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
state,
rate,
sponge_state: SpongeState::Absorbed(0),
params,
constants: std::marker::PhantomData,
}
}

fn absorb(&mut self, params: &ArithmeticSpongeParams<F>, x: &[F]) {
fn absorb(&mut self, x: &[F]) {
for x in x.iter() {
match self.sponge_state {
SpongeState::Absorbed(n) => {
if n == self.rate {
self.poseidon_block_cipher(params);
self.poseidon_block_cipher();
self.sponge_state = SpongeState::Absorbed(1);
self.state[0].add_assign(x);
} else {
Expand All @@ -239,11 +238,11 @@ impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
}
}

fn squeeze(&mut self, params: &ArithmeticSpongeParams<F>) -> F {
fn squeeze(&mut self) -> F {
match self.sponge_state {
SpongeState::Squeezed(n) => {
if n == self.rate {
self.poseidon_block_cipher(params);
self.poseidon_block_cipher();
self.sponge_state = SpongeState::Squeezed(1);
self.state[0]
} else {
Expand All @@ -252,7 +251,7 @@ impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
}
}
SpongeState::Absorbed(_n) => {
self.poseidon_block_cipher(params);
self.poseidon_block_cipher();
self.sponge_state = SpongeState::Squeezed(1);
self.state[0]
}
Expand Down
20 changes: 8 additions & 12 deletions oracle/src/sponge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ impl<F: PrimeField> ScalarChallenge<F> {

#[derive(Clone)]
pub struct DefaultFqSponge<P: SWModelParameters, SC: SpongeConstants> {
pub params: ArithmeticSpongeParams<P::BaseField>,
pub sponge: ArithmeticSponge<P::BaseField, SC>,
pub last_squeezed: Vec<u64>,
}

pub struct DefaultFrSponge<Fr: Field, SC: SpongeConstants> {
pub params: ArithmeticSpongeParams<Fr>,
pub sponge: ArithmeticSponge<Fr, SC>,
pub last_squeezed: Vec<u64>,
}
Expand All @@ -90,7 +88,7 @@ impl<Fr: PrimeField, SC: SpongeConstants> DefaultFrSponge<Fr, SC> {
self.last_squeezed = remaining.to_vec();
Fr::from_repr(pack::<Fr::BigInt>(&limbs))
} else {
let x = self.sponge.squeeze(&self.params).into_repr();
let x = self.sponge.squeeze().into_repr();
self.last_squeezed
.extend(&x.as_ref()[0..HIGH_ENTROPY_LIMBS]);
self.squeeze(num_limbs)
Expand All @@ -110,7 +108,7 @@ where
self.last_squeezed = remaining.to_vec();
limbs.to_vec()
} else {
let x = self.sponge.squeeze(&self.params).into_repr();
let x = self.sponge.squeeze().into_repr();
self.last_squeezed
.extend(&x.as_ref()[0..HIGH_ENTROPY_LIMBS]);
self.squeeze_limbs(num_limbs)
Expand All @@ -119,7 +117,7 @@ where

pub fn squeeze_field(&mut self) -> P::BaseField {
self.last_squeezed = vec![];
self.sponge.squeeze(&self.params)
self.sponge.squeeze()
}

pub fn squeeze(&mut self, num_limbs: usize) -> P::ScalarField {
Expand All @@ -135,8 +133,7 @@ where
{
fn new(params: ArithmeticSpongeParams<P::BaseField>) -> DefaultFqSponge<P, SC> {
DefaultFqSponge {
params,
sponge: ArithmeticSponge::new(),
sponge: ArithmeticSponge::new(params),
last_squeezed: vec![],
}
}
Expand All @@ -147,8 +144,8 @@ where
if g.infinity {
panic!("sponge got zero curve point");
} else {
self.sponge.absorb(&self.params, &[g.x]);
self.sponge.absorb(&self.params, &[g.y]);
self.sponge.absorb(&[g.x]);
self.sponge.absorb(&[g.y]);
}
}
}
Expand All @@ -172,7 +169,6 @@ where
< <P::BaseField as PrimeField>::Params::MODULUS.into()
{
self.sponge.absorb(
&self.params,
&[P::BaseField::from_repr(
<P::BaseField as PrimeField>::BigInt::from_bits(&bits),
)],
Expand All @@ -188,8 +184,8 @@ where
P::BaseField::zero()
};

self.sponge.absorb(&self.params, &[low_bits]);
self.sponge.absorb(&self.params, &[high_bit]);
self.sponge.absorb(&[low_bits]);
self.sponge.absorb(&[high_bit]);
}
});
}
Expand Down
Loading

0 comments on commit d89ad19

Please sign in to comment.