Skip to content

Commit

Permalink
expose block cipher function from poseidon
Browse files Browse the repository at this point in the history
  • Loading branch information
Izaak Meckler committed Sep 14, 2021
1 parent 6b93db7 commit bacef43
Showing 1 changed file with 84 additions and 69 deletions.
153 changes: 84 additions & 69 deletions oracle/src/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,91 +109,106 @@ pub struct ArithmeticSponge<F: Field, SC: SpongeConstants> {
pub constants: std::marker::PhantomData<SC>,
}

impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
fn apply_mds_matrix(&mut self) {
self.state = if SC::FULL_MDS {
self.params
.mds
.iter()
.map(|m| {
self.state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
} else {
vec![
self.state[0] + &self.state[2],
self.state[0] + &self.state[1],
self.state[1] + &self.state[2],
]
};
fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>, state: & Vec<F>) -> Vec<F> {
if SC::FULL_MDS {
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
} else {
vec![
state[0] + &state[2],
state[0] + &state[1],
state[1] + &state[2],
]
}
}

pub fn full_round(&mut self, r: usize) {
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>, r: usize) {
for i in 0..state.len() {
state[i] = sbox::<F, SC>(state[i]);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}

fn half_rounds<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>) {
for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
self.apply_mds_matrix();
for (i, x) in self.params.round_constants[r].iter().enumerate() {
self.state[i].add_assign(x);
for i in 0..state.len() {
state[i] = sbox::<F, SC>(state[i]);
}
apply_mds_matrix::<F, SC>(params, state);
}

fn half_rounds(&mut self) {
for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in self.params.round_constants[r].iter().enumerate() {
self.state[i].add_assign(x);
}
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
}
self.apply_mds_matrix();
for r in 0..SC::ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
state[0] = sbox::<F, SC>(state[0]);
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::ROUNDS_PARTIAL {
for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
self.state[i].add_assign(x);
}
self.state[0] = sbox::<F, SC>(self.state[0]);
self.apply_mds_matrix();
for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
for i in 0..state.len() {
state[i] = sbox::<F, SC>(state[i]);
}
apply_mds_matrix::<F, SC>(params, state);
}
}

for r in 0..SC::HALF_ROUNDS_FULL {
for (i, x) in self.params.round_constants[SC::HALF_ROUNDS_FULL + SC::ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
self.state[i].add_assign(x);
pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>) {
if SC::HALF_ROUNDS_FULL == 0 {
if SC::INITIAL_ARK == true {
for (i, x) in params.round_constants[0].iter().enumerate() {
state[i].add_assign(x);
}
for r in 0..SC::ROUNDS_FULL {
full_round::<F, SC>(params, state, r + 1);
}
for i in 0..self.state.len() {
self.state[i] = sbox::<F, SC>(self.state[i]);
} else {
for r in 0..SC::ROUNDS_FULL {
full_round::<F, SC>(params, state, r);
}
self.apply_mds_matrix();
}
} else {
half_rounds::<F, SC>(params, state);
}
}

impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
pub fn full_round(&mut self, r: usize) {
full_round::<F, SC>(&self.params, &mut self.state, r);
}

fn poseidon_block_cipher(&mut self) {
if SC::HALF_ROUNDS_FULL == 0 {
if SC::INITIAL_ARK == true {
for (i, x) in self.params.round_constants[0].iter().enumerate() {
self.state[i].add_assign(x);
}
for r in 0..SC::ROUNDS_FULL {
self.full_round(r + 1);
}
} else {
for r in 0..SC::ROUNDS_FULL {
self.full_round(r);
}
}
} else {
self.half_rounds();
}
poseidon_block_cipher::<F, SC>(&self.params, &mut self.state);
}
}

Expand Down

0 comments on commit bacef43

Please sign in to comment.