Skip to content

Commit

Permalink
Remove extra SHL/SHR CTL. (#1270)
Browse files Browse the repository at this point in the history
* Remove extra shift CTL.

* Change order of inputs for the arithmetic shift operations. Add SHR test. Fix max number of bit shifts. Cleanup.

* Fix SHR in the case shift >= 256

* Limit visibility of helper functions
  • Loading branch information
LindaGuiga authored Oct 5, 2023
1 parent 51eb7c0 commit 0de6f94
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 106 deletions.
5 changes: 1 addition & 4 deletions evm/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {

fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_shift_rows(),
],
vec![cpu_stark::ctl_arithmetic_base_rows()],
arithmetic_stark::ctl_arithmetic_rows(),
)
}
Expand Down
3 changes: 3 additions & 0 deletions evm/src/arithmetic/arithmetic_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use plonky2::util::transpose;
use static_assertions::const_assert;

use super::columns::NUM_ARITH_COLUMNS;
use super::shift;
use crate::all_stark::Table;
use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS};
use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation};
Expand Down Expand Up @@ -208,6 +209,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_packed(lv, nv, yield_constr);
modular::eval_packed(lv, nv, yield_constr);
byte::eval_packed(lv, yield_constr);
shift::eval_packed_generic(lv, nv, yield_constr);
}

fn eval_ext_circuit(
Expand Down Expand Up @@ -237,6 +239,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_ext_circuit(builder, lv, nv, yield_constr);
modular::eval_ext_circuit(builder, lv, nv, yield_constr);
byte::eval_ext_circuit(builder, lv, yield_constr);
shift::eval_ext_circuit(builder, lv, nv, yield_constr);
}

fn constraint_degree(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion evm/src/arithmetic/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_REGISTER_0;
pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start;
pub(crate) const MODULAR_AUX_INPUT_LO: Range<usize> = AUX_REGISTER_1.start + 1..AUX_REGISTER_1.end;
pub(crate) const MODULAR_AUX_INPUT_HI: Range<usize> = AUX_REGISTER_2;
// Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV]
// Must be set to MOD_IS_ZERO for DIV and SHR operations i.e. MOD_IS_ZERO * (lv[IS_DIV] + lv[IS_SHR]).
pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end;

/// The counter column (used for the range check) starts from 0 and increments.
Expand Down
77 changes: 49 additions & 28 deletions evm/src/arithmetic/divmod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,19 @@ use crate::arithmetic::modular::{
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
/// Generates the output and auxiliary values for modular operations,
/// assuming the input, modular and output limbs are already set.
pub(crate) fn generate_divmod<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
input_limbs_range: Range<usize>,
modulus_range: Range<usize>,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);

u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);

let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, input_limbs_range);
let pol_input = pol_extend(input_limbs);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, modulus_range);

debug_assert!(
&quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO),
"expected top half of quo_input to be zero"
Expand Down Expand Up @@ -62,16 +57,35 @@ pub(crate) fn generate<F: PrimeField64>(
);
lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]);
}
_ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"),
_ => panic!("expected filter to be IS_DIV, IS_SHR or IS_MOD but it was {filter}"),
};
}
/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);

u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);

generate_divmod(lv, nv, filter, INPUT_REGISTER_0, INPUT_REGISTER_1);
}

/// Verify that num = quo * den + rem and 0 <= rem < den.
fn eval_packed_divmod_helper<P: PackedField>(
pub(crate) fn eval_packed_divmod_helper<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
filter: P,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
Expand All @@ -80,8 +94,8 @@ fn eval_packed_divmod_helper<P: PackedField>(

yield_constr.constraint_last_row(filter);

let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let mut quo = [P::ZEROS; 2 * N_LIMBS];
quo[..N_LIMBS].copy_from_slice(&lv[quo_range]);
Expand All @@ -104,14 +118,13 @@ pub(crate) fn eval_packed<P: PackedField>(
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
// Constrain IS_SHR independently, so that it doesn't impact the
// constraints when combining the flag with IS_DIV.
yield_constr.constraint_last_row(lv[IS_SHR]);
eval_packed_divmod_helper(
lv,
nv,
yield_constr,
lv[IS_DIV] + lv[IS_SHR],
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
Expand All @@ -120,24 +133,28 @@ pub(crate) fn eval_packed<P: PackedField>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
}

fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
pub(crate) fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
yield_constr.constraint_last_row(builder, filter);

let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let zero = builder.zero_extension();
let mut quo = [zero; 2 * N_LIMBS];
Expand All @@ -164,14 +181,14 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
yield_constr.constraint_last_row(builder, lv[IS_SHR]);
let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
eval_ext_circuit_divmod_helper(
builder,
lv,
nv,
yield_constr,
div_shr_flag,
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
Expand All @@ -181,6 +198,8 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
Expand Down Expand Up @@ -214,7 +233,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Deactivate the SHR flag so that a DIV operation is not triggered.
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;

let mut constraint_consumer = ConstraintConsumer::new(
Expand Down Expand Up @@ -247,6 +266,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;

Expand Down Expand Up @@ -308,6 +328,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;

Expand Down
33 changes: 29 additions & 4 deletions evm/src/arithmetic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod byte;
mod divmod;
mod modular;
mod mul;
mod shift;
mod utils;

pub mod arithmetic_stark;
Expand All @@ -35,15 +36,29 @@ impl BinaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 {
match self {
BinaryOperator::Add => input0.overflowing_add(input1).0,
BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0,
BinaryOperator::Mul => input0.overflowing_mul(input1).0,
BinaryOperator::Shl => {
if input0 < U256::from(256usize) {
input1 << input0
} else {
U256::zero()
}
}
BinaryOperator::Sub => input0.overflowing_sub(input1).0,
BinaryOperator::Div | BinaryOperator::Shr => {
BinaryOperator::Div => {
if input1.is_zero() {
U256::zero()
} else {
input0 / input1
}
}
BinaryOperator::Shr => {
if input0 < U256::from(256usize) {
input1 >> input0
} else {
U256::zero()
}
}
BinaryOperator::Mod => {
if input1.is_zero() {
U256::zero()
Expand Down Expand Up @@ -238,15 +253,25 @@ fn binary_op_to_rows<F: PrimeField64>(
addcy::generate(&mut row, op.row_filter(), input0, input1);
(row, None)
}
BinaryOperator::Mul | BinaryOperator::Shl => {
BinaryOperator::Mul => {
mul::generate(&mut row, input0, input1);
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => {
BinaryOperator::Shl => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, true, input0, input1, result);
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result);
(row, Some(nv))
}
BinaryOperator::Shr => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, false, input0, input1, result);
(row, Some(nv))
}
BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => {
ternary_op_to_rows::<F>(op.row_filter(), input0, input1, BN_BASE, result)
}
Expand Down
24 changes: 15 additions & 9 deletions evm/src/arithmetic/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(

let mut mod_is_zero = F::ZERO;
if modulus.is_zero() {
if filter == columns::IS_DIV {
if filter == columns::IS_DIV || filter == columns::IS_SHR {
// set modulus = 2^256; the condition above means we know
// it's zero at this point, so we can just set bit 256.
modulus.set_bit(256, true);
Expand Down Expand Up @@ -330,7 +330,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(

nv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64));
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV];
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);

(
output_limbs.map(F::from_canonical_i64),
Expand Down Expand Up @@ -392,14 +392,14 @@ pub(crate) fn check_reduced<P: PackedField>(
// Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
// This sets is_less_than to 1 unless we get mod_is_zero when
// doing a DIV; in that case, we need is_less_than=0, since
// doing a DIV or SHR; in that case, we need is_less_than=0, since
// eval_packed_generic_addcy checks
//
// modulus + out_aux_red == output + is_less_than*2^256
//
// and we are given output = out_aux_red when modulus is zero.
let mut is_less_than = [P::ZEROS; N_LIMBS];
is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV];
is_less_than[0] = P::ONES - mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);
// NB: output and modulus in lv while out_aux_red and
// is_less_than (via mod_is_zero) depend on nv, hence the
// 'is_two_row_op' argument is set to 'true'.
Expand Down Expand Up @@ -448,13 +448,15 @@ pub(crate) fn modular_constr_poly<P: PackedField>(
// modulus = 0.
modulus[0] += mod_is_zero;

// Is 1 iff the operation is DIV and the denominator is zero.
// Is 1 iff the operation is DIV or SHR and the denominator is zero.
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero));
yield_constr.constraint_transition(
filter * (mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]) - div_denom_is_zero),
);

// Needed to compensate for adding mod_is_zero to modulus above,
// since the call eval_packed_generic_addcy() below subtracts modulus
// to verify in the case of a DIV.
// to verify in the case of a DIV or SHR.
output[0] += div_denom_is_zero;

check_reduced(lv, nv, yield_constr, filter, output, modulus, mod_is_zero);
Expand Down Expand Up @@ -635,7 +637,8 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
modulus[0] = builder.add_extension(modulus[0], mod_is_zero);

let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero);
let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero);
let t = builder.mul_extension(filter, t);
yield_constr.constraint_transition(builder, t);
output[0] = builder.add_extension(output[0], div_denom_is_zero);
Expand All @@ -645,7 +648,7 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
let zero = builder.zero_extension();
let mut is_less_than = [zero; N_LIMBS];
is_less_than[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, div_shr_filter, one);

eval_ext_circuit_addcy(
builder,
Expand Down Expand Up @@ -834,6 +837,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;

Expand Down Expand Up @@ -867,6 +871,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;
Expand Down Expand Up @@ -926,6 +931,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;
Expand Down
Loading

0 comments on commit 0de6f94

Please sign in to comment.