Skip to content

Commit

Permalink
feat: simplify constant MSM calls in SSA (#6547)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Nov 19, 2024
1 parent 245f50d commit f291e37
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 2 deletions.
1 change: 0 additions & 1 deletion acvm-repo/bn254_blackbox_solver/src/embedded_curve_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub fn multi_scalar_mul(
scalars_hi: &[FieldElement],
) -> Result<(FieldElement, FieldElement, FieldElement), BlackBoxResolutionError> {
if points.len() != 3 * scalars_lo.len() || scalars_lo.len() != scalars_hi.len() {
dbg!(&points.len(), &scalars_lo.len(), &scalars_hi.len());
return Err(BlackBoxResolutionError::Failed(
BlackBoxFunc::MultiScalarMul,
"Points and scalars must have the same length".to_string(),
Expand Down
4 changes: 3 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,9 @@ fn simplify_black_box_func(
acvm::blackbox_solver::ecdsa_secp256r1_verify,
),

BlackBoxFunc::MultiScalarMul => SimplifyResult::None,
BlackBoxFunc::MultiScalarMul => {
blackbox::simplify_msm(dfg, solver, arguments, block, call_stack)
}
BlackBoxFunc::EmbeddedCurveAdd => {
blackbox::simplify_ec_add(dfg, solver, arguments, block, call_stack)
}
Expand Down
58 changes: 58 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,64 @@ pub(super) fn simplify_ec_add(
}
}

pub(super) fn simplify_msm(
dfg: &mut DataFlowGraph,
solver: impl BlackBoxFunctionSolver<FieldElement>,
arguments: &[ValueId],
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
// TODO: Handle MSMs where a subset of the terms are constant.
match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
let Some(points) = points
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let Some(scalars) = scalars
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars.into_iter().enumerate() {
if i % 2 == 0 {
scalars_lo.push(scalar);
} else {
scalars_hi.push(scalar);
}
}

let Ok((result_x, result_y, result_is_infinity)) =
solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
else {
return SimplifyResult::None;
};

let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array =
dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());

SimplifyResult::SimplifiedTo(result_array.first())
}
_ => SimplifyResult::None,
}
}

pub(super) fn simplify_poseidon2_permutation(
dfg: &mut DataFlowGraph,
solver: impl BlackBoxFunctionSolver<FieldElement>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "embedded_curve_msm_simplification"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fn main() {
let pub_x = 0x0000000000000000000000000000000000000000000000000000000000000001;
let pub_y = 0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c;

let g1_y = 17631683881184975370165255887551781615748388533673675138860;
let g1 = std::embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: g1_y, is_infinite: false };
let scalar = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 1, hi: 0 };
// Test that multi_scalar_mul correctly derives the public key
let res = std::embedded_curve_ops::multi_scalar_mul([g1], [scalar]);
assert(res.x == pub_x);
assert(res.y == pub_y);
}

0 comments on commit f291e37

Please sign in to comment.