diff --git a/vm/src/vm/runners/builtin_runner/modulo.rs b/vm/src/vm/runners/builtin_runner/modulo.rs index 572ccf1897..2324d42148 100644 --- a/vm/src/vm/runners/builtin_runner/modulo.rs +++ b/vm/src/vm/runners/builtin_runner/modulo.rs @@ -47,6 +47,7 @@ pub struct ModBuiltinRunner { // Precomputed powers used for reading and writing values that are represented as n_words words of word_bit_len bits each. shift: BigUint, shift_powers: [BigUint; N_WORDS], + k_bound: BigUint, } #[derive(Debug, Clone)] @@ -60,7 +61,7 @@ pub enum Operation { Mul, Add, Sub, - DivMod(BigUint), + DivMod, } impl Display for Operation { @@ -69,7 +70,7 @@ impl Display for Operation { Operation::Mul => "*".fmt(f), Operation::Add => "+".fmt(f), Operation::Sub => "-".fmt(f), - Operation::DivMod(_) => "/".fmt(f), + Operation::DivMod => "/".fmt(f), } } } @@ -85,17 +86,29 @@ struct Inputs { impl ModBuiltinRunner { pub(crate) fn new_add_mod(instance_def: &ModInstanceDef, included: bool) -> Self { - Self::new(instance_def.clone(), included, ModBuiltinType::Add) + Self::new( + instance_def.clone(), + included, + ModBuiltinType::Add, + Some(2u32.into()), + ) } pub(crate) fn new_mul_mod(instance_def: &ModInstanceDef, included: bool) -> Self { - Self::new(instance_def.clone(), included, ModBuiltinType::Mul) + Self::new(instance_def.clone(), included, ModBuiltinType::Mul, None) } - fn new(instance_def: ModInstanceDef, included: bool, builtin_type: ModBuiltinType) -> Self { + fn new( + instance_def: ModInstanceDef, + included: bool, + builtin_type: ModBuiltinType, + k_bound: Option, + ) -> Self { let shift = BigUint::one().shl(instance_def.word_bit_len); let shift_powers = core::array::from_fn(|i| shift.pow(i as u32)); let zero_segment_size = core::cmp::max(N_WORDS, instance_def.batch_size * 3); + let int_lim = BigUint::from(2_u32).pow(N_WORDS as u32 * instance_def.word_bit_len); + dbg!(&int_lim); Self { builtin_type, base: 0, @@ -106,6 +119,7 @@ impl ModBuiltinRunner { zero_segment_size, shift, shift_powers, + k_bound: k_bound.unwrap_or(int_lim), } } @@ -458,19 +472,19 @@ impl ModBuiltinRunner { match (a, b, c) { // Deduce c from a and b and write it to memory. (Some(a), Some(b), None) => { - let value = apply_op(a, b, op)?.mod_floor(&inputs.p); + let value = apply_op(op, a, b, &inputs.p, &self.k_bound)?; self.write_n_words_value(memory, addresses[2], value)?; Ok(true) } // Deduce b from a and c and write it to memory. (Some(a), None, Some(c)) => { - let value = apply_op(c, a, inv_op)?.mod_floor(&inputs.p); + let value = apply_op(inv_op, c, a, &inputs.p, &self.k_bound)?; self.write_n_words_value(memory, addresses[1], value)?; Ok(true) } // Deduce a from b and c and write it to memory. (None, Some(b), Some(c)) => { - let value = apply_op(c, b, inv_op)?.mod_floor(&inputs.p); + let value = apply_op(inv_op, c, b, &inputs.p, &self.k_bound)?; self.write_n_words_value(memory, addresses[0], value)?; Ok(true) } @@ -539,44 +553,45 @@ impl ModBuiltinRunner { Default::default() }; - // Get one of the builtin runners - the rest of this function doesn't depend on batch_size. - let mod_runner = if let Some((_, add_mod, _)) = add_mod { - add_mod - } else { - mul_mod.unwrap().1 - }; // Fill the values table. let mut add_mod_index = 0; let mut mul_mod_index = 0; - // Create operation here to avoid cloning p in the loop - let div_operation = Operation::DivMod(mul_mod_inputs.p.clone()); + while add_mod_index < add_mod_n || mul_mod_index < mul_mod_n { - if add_mod_index < add_mod_n - && mod_runner.fill_value( - memory, - &add_mod_inputs, - add_mod_index, - &Operation::Add, - &Operation::Sub, - )? - { - add_mod_index += 1; - } else if mul_mod_index < mul_mod_n - && mod_runner.fill_value( - memory, - &mul_mod_inputs, - mul_mod_index, - &Operation::Mul, - &div_operation, - )? - { - mul_mod_index += 1; - } else { - return Err(RunnerError::FillMemoryCoudNotFillTable( - add_mod_index, - mul_mod_index, - )); + if add_mod_index < add_mod_n { + if let Some((_, add_mod_runner, _)) = add_mod { + if add_mod_runner.fill_value( + memory, + &add_mod_inputs, + add_mod_index, + &Operation::Add, + &Operation::Sub, + )? { + add_mod_index += 1; + continue; + } + } } + + if mul_mod_index < mul_mod_n { + if let Some((_, mul_mod_runner, _)) = mul_mod { + if mul_mod_runner.fill_value( + memory, + &mul_mod_inputs, + mul_mod_index, + &Operation::Mul, + &Operation::DivMod, + )? { + mul_mod_index += 1; + } + continue; + } + } + + return Err(RunnerError::FillMemoryCoudNotFillTable( + add_mod_index, + mul_mod_index, + )); } Ok(()) } @@ -629,7 +644,7 @@ impl ModBuiltinRunner { ModBuiltinType::Add => Operation::Add, ModBuiltinType::Mul => Operation::Mul, }; - let a_op_b = apply_op(&a, &b, &op)?.mod_floor(&inputs.p); + let a_op_b = apply_op(&op, &a, &b, &inputs.p, &self.k_bound)?; if a_op_b != c.mod_floor(&inputs.p) { // Build error string let p = inputs.p; @@ -669,13 +684,46 @@ impl ModBuiltinRunner { } } -fn apply_op(lhs: &BigUint, rhs: &BigUint, op: &Operation) -> Result { - Ok(match op { - Operation::Mul => lhs * rhs, - Operation::Add => lhs + rhs, - Operation::Sub => lhs - rhs, - Operation::DivMod(ref p) => div_mod_unsigned(lhs, rhs, p)?, - }) +fn apply_op( + op: &Operation, + lhs: &BigUint, + rhs: &BigUint, + p: &BigUint, + k_bound: &BigUint, +) -> Result { + println!("== GATE =="); + println!("op = {:?}", op); + println!("lhs = {:?}", lhs); + println!("rhs = {:?}", rhs); + println!("k_bound = {:?}", k_bound); + let value = match op { + Operation::Mul => { + let value = lhs * rhs; + if value < k_bound * p { + value.mod_floor(p) + } else { + value - (k_bound - 1u32) * p + } + } + Operation::Add => { + let value = lhs + rhs; + if value < k_bound * p { + value.mod_floor(p) + } else { + value - (k_bound - 1u32) * p + } + } + Operation::Sub => { + if rhs <= lhs { + lhs - rhs + } else { + lhs + p - rhs + } + } + Operation::DivMod => div_mod_unsigned(lhs, rhs, p)?, + }; + println!("value = {:?}", value); + Ok(value) } #[cfg(test)]