From f969fb4d377162bd00c629acaa28027d048d8e8e Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 6 Aug 2021 20:37:06 +0200 Subject: [PATCH 01/78] apply propagation in zir --- changelogs/unreleased/957-dark64 | 1 + zokrates_core/src/static_analysis/mod.rs | 4 + .../src/static_analysis/zir_propagation.rs | 516 ++++++++++++++++++ zokrates_core/src/zir/mod.rs | 12 +- zokrates_core/src/zir/uint.rs | 2 +- 5 files changed, 528 insertions(+), 7 deletions(-) create mode 100644 changelogs/unreleased/957-dark64 create mode 100644 zokrates_core/src/static_analysis/zir_propagation.rs diff --git a/changelogs/unreleased/957-dark64 b/changelogs/unreleased/957-dark64 new file mode 100644 index 000000000..2c4f7389d --- /dev/null +++ b/changelogs/unreleased/957-dark64 @@ -0,0 +1 @@ +Apply propagation in ZIR \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index e66f15055..686eca1f3 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -14,6 +14,7 @@ mod shift_checker; mod uint_optimizer; mod unconstrained_vars; mod variable_write_remover; +mod zir_propagation; use self::branch_isolator::Isolator; use self::flatten_complex_types::Flattener; @@ -27,6 +28,7 @@ use crate::compile::CompileConfig; use crate::flat_absy::FlatProg; use crate::ir::Prog; use crate::static_analysis::constant_inliner::ConstantInliner; +use crate::static_analysis::zir_propagation::ZirPropagator; use crate::typed_absy::{abi::Abi, TypedProgram}; use crate::zir::ZirProgram; use std::fmt; @@ -94,6 +96,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = ShiftChecker::check(r).map_err(Error::from)?; // convert to zir, removing complex types let zir = Flattener::flatten(r); + // apply propagation to zir + let zir = ZirPropagator::propagate(zir); // optimize uint expressions let zir = UintOptimizer::optimize(zir); diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs new file mode 100644 index 000000000..35e1b94be --- /dev/null +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -0,0 +1,516 @@ +use crate::zir::{ + BooleanExpression, FieldElementExpression, Folder, UExpression, UExpressionInner, Variable, + ZirExpression, ZirFunction, ZirProgram, ZirStatement, +}; +use std::collections::HashMap; +use zokrates_field::Field; + +type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; + +trait Propagator<'ast, T> { + type Output; + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output; +} + +#[derive(Default)] +pub struct ZirPropagator; + +impl ZirPropagator { + pub fn new() -> Self { + ZirPropagator::default() + } + pub fn propagate(p: ZirProgram) -> ZirProgram { + ZirPropagator::new().fold_program(p) + } +} + +impl<'ast, T: Field> Propagator<'ast, T> for ZirStatement<'ast, T> { + type Output = Option; + + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { + match self { + ZirStatement::Assertion(e) => match e.propagate(constants) { + BooleanExpression::Value(true) => None, + e => Some(ZirStatement::Assertion(e)), + }, + ZirStatement::Definition(a, e) => { + let e = e.propagate(constants); + match e { + ZirExpression::FieldElement(FieldElementExpression::Number(_)) + | ZirExpression::Boolean(BooleanExpression::Value(_)) => { + constants.insert(a, e); + None + } + ZirExpression::Uint(e) => match e.inner { + UExpressionInner::Value(_) => { + constants.insert(a, ZirExpression::Uint(e)); + None + } + _ => Some(ZirStatement::Definition(a, ZirExpression::Uint(e))), + }, + _ => Some(ZirStatement::Definition(a, e)), + } + } + ZirStatement::IfElse(e, consequence, alternative) => Some(ZirStatement::IfElse( + e.propagate(constants), + consequence + .into_iter() + .filter_map(|s| s.propagate(constants)) + .collect(), + alternative + .into_iter() + .filter_map(|s| s.propagate(constants)) + .collect(), + )), + ZirStatement::Return(e) => Some(ZirStatement::Return( + e.into_iter().map(|e| e.propagate(constants)).collect(), + )), + ZirStatement::MultipleDefinition(assignees, list) => { + // TODO: apply propagation here + Some(ZirStatement::MultipleDefinition(assignees, list)) + } + } + } +} + +impl<'ast, T: Field> Propagator<'ast, T> for ZirExpression<'ast, T> { + type Output = Self; + + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { + match self { + ZirExpression::Boolean(e) => ZirExpression::Boolean(e.propagate(constants)), + ZirExpression::FieldElement(e) => ZirExpression::FieldElement(e.propagate(constants)), + ZirExpression::Uint(e) => ZirExpression::Uint(e.propagate(constants)), + } + } +} + +impl<'ast, T: Field> Propagator<'ast, T> for UExpression<'ast, T> { + type Output = Self; + + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { + UExpression { + inner: match self.inner { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Identifier(id) => { + match constants.get(&Variable::uint(id.clone(), self.bitwidth)) { + Some(ZirExpression::Uint(e)) => match e.inner { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + _ => unreachable!("should contain constant uint value"), + }, + _ => UExpressionInner::Identifier(id), + } + } + UExpressionInner::Select(e, box index) => { + let index = index.propagate(constants); + match index.inner { + UExpressionInner::Value(v) => e + .get(v as usize) + .cloned() + .expect("index out of bounds") + .into_inner(), + _ => UExpressionInner::Select(e, box index), + } + } + UExpressionInner::Add(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 + n2) + } + _ => UExpressionInner::Add(box e1, box e2), + } + } + UExpressionInner::Sub(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 - n2) + } + _ => UExpressionInner::Sub(box e1, box e2), + } + } + UExpressionInner::Mult(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 * n2) + } + _ => UExpressionInner::Mult(box e1, box e2), + } + } + UExpressionInner::Div(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 / n2) + } + _ => UExpressionInner::Div(box e1, box e2), + } + } + UExpressionInner::Rem(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 % n2) + } + _ => UExpressionInner::Rem(box e1, box e2), + } + } + UExpressionInner::Xor(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 ^ n2) + } + _ => UExpressionInner::Xor(box e1, box e2), + } + } + UExpressionInner::And(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 & n2) + } + _ => UExpressionInner::And(box e1, box e2), + } + } + UExpressionInner::Or(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 | n2) + } + _ => UExpressionInner::Or(box e1, box e2), + } + } + UExpressionInner::LeftShift(box e, v) => { + let e = e.propagate(constants); + match &e.inner { + UExpressionInner::Value(n) => UExpressionInner::Value(n << v), + _ => UExpressionInner::LeftShift(box e, v), + } + } + UExpressionInner::RightShift(box e, v) => { + let e = e.propagate(constants); + match &e.inner { + UExpressionInner::Value(n) => UExpressionInner::Value(n >> v), + _ => UExpressionInner::RightShift(box e, v), + } + } + UExpressionInner::Not(box e) => { + let e = e.propagate(constants); + match &e.inner { + UExpressionInner::Value(n) => UExpressionInner::Value(!*n), + _ => UExpressionInner::Not(box e), + } + } + UExpressionInner::IfElse(box condition, box consequence, box alternative) => { + let condition = condition.propagate(constants); + match condition { + BooleanExpression::Value(true) => consequence.into_inner(), + BooleanExpression::Value(false) => alternative.into_inner(), + _ => UExpressionInner::IfElse( + box condition, + box consequence, + box alternative, + ), + } + } + }, + ..self + } + } +} + +impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { + type Output = Self; + + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { + match self { + FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Identifier(id) => { + match constants.get(&Variable::field_element(id.clone())) { + Some(ZirExpression::FieldElement(FieldElementExpression::Number(v))) => { + FieldElementExpression::Number(v.clone()) + } + _ => FieldElementExpression::Identifier(id), + } + } + FieldElementExpression::Select(e, box index) => { + let index = index.propagate(constants); + match index.inner { + UExpressionInner::Value(v) => { + e.get(v as usize).cloned().expect("index out of bounds") + } + _ => FieldElementExpression::Select(e, box index), + } + } + FieldElementExpression::Add(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + FieldElementExpression::Number(n1 + n2) + } + (e1, e2) => FieldElementExpression::Add(box e1, box e2), + } + } + FieldElementExpression::Sub(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + FieldElementExpression::Number(n1 - n2) + } + (e1, e2) => FieldElementExpression::Sub(box e1, box e2), + } + } + FieldElementExpression::Mult(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + FieldElementExpression::Number(n1 * n2) + } + (e1, e2) => FieldElementExpression::Mult(box e1, box e2), + } + } + FieldElementExpression::Div(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + FieldElementExpression::Number(n1 / n2) + } + (e1, e2) => FieldElementExpression::Div(box e1, box e2), + } + } + FieldElementExpression::Pow(box e, box exponent) => { + let exponent = exponent.propagate(constants); + match (e.propagate(constants), &exponent.inner) { + (_, UExpressionInner::Value(n2)) if *n2 == 0 => { + FieldElementExpression::Number(T::from(1)) + } + (FieldElementExpression::Number(n), UExpressionInner::Value(e)) => { + FieldElementExpression::Number(n.pow(*e as usize)) + } + (e, _) => FieldElementExpression::Pow(box e, box exponent), + } + } + FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { + let condition = condition.propagate(constants); + match condition { + BooleanExpression::Value(true) => consequence, + BooleanExpression::Value(false) => alternative, + _ => FieldElementExpression::IfElse( + box condition, + box consequence, + box alternative, + ), + } + } + } + } +} + +impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { + type Output = Self; + + fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { + match self { + BooleanExpression::Value(v) => BooleanExpression::Value(v), + BooleanExpression::Identifier(id) => { + match constants.get(&Variable::boolean(id.clone())) { + Some(ZirExpression::Boolean(BooleanExpression::Value(v))) => { + BooleanExpression::Value(*v) + } + _ => BooleanExpression::Identifier(id), + } + } + BooleanExpression::Select(e, box index) => { + let index = index.propagate(constants); + match index.inner { + UExpressionInner::Value(v) => { + e.get(v as usize).cloned().expect("index out of bounds") + } + _ => BooleanExpression::Select(e, box index), + } + } + BooleanExpression::FieldLt(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + BooleanExpression::Value(n1 < n2) + } + (e1, e2) => BooleanExpression::FieldLt(box e1, box e2), + } + } + BooleanExpression::FieldLe(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + BooleanExpression::Value(n1 <= n2) + } + (e1, e2) => BooleanExpression::FieldLe(box e1, box e2), + } + } + BooleanExpression::FieldGe(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + BooleanExpression::Value(n1 >= n2) + } + (e1, e2) => BooleanExpression::FieldGe(box e1, box e2), + } + } + BooleanExpression::FieldGt(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { + BooleanExpression::Value(n1 > n2) + } + (e1, e2) => BooleanExpression::FieldGt(box e1, box e2), + } + } + BooleanExpression::FieldEq(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (FieldElementExpression::Number(v1), FieldElementExpression::Number(v2)) => { + BooleanExpression::Value(v1.eq(&v2)) + } + (e1, e2) => { + if e1.eq(&e2) { + BooleanExpression::Value(true) + } else { + BooleanExpression::FieldEq(box e1, box e2) + } + } + } + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + BooleanExpression::Value(v1 < v2) + } + _ => BooleanExpression::UintLt(box e1, box e2), + } + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + BooleanExpression::Value(v1 <= v2) + } + _ => BooleanExpression::UintLe(box e1, box e2), + } + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + BooleanExpression::Value(v1 >= v2) + } + _ => BooleanExpression::UintGe(box e1, box e2), + } + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + BooleanExpression::Value(v1 > v2) + } + _ => BooleanExpression::UintGt(box e1, box e2), + } + } + BooleanExpression::UintEq(box e1, box e2) => { + let e1 = e1.propagate(constants); + let e2 = e2.propagate(constants); + + match (&e1.inner, &e2.inner) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + BooleanExpression::Value(v1 == v2) + } + _ => { + if e1.eq(&e2) { + BooleanExpression::Value(true) + } else { + BooleanExpression::UintEq(box e1, box e2) + } + } + } + } + BooleanExpression::BoolEq(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { + BooleanExpression::Value(v1 == v2) + } + (e1, e2) => { + if e1.eq(&e2) { + BooleanExpression::Value(true) + } else { + BooleanExpression::BoolEq(box e1, box e2) + } + } + } + } + BooleanExpression::Or(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { + BooleanExpression::Value(v1 || v2) + } + (e1, e2) => BooleanExpression::Or(box e1, box e2), + } + } + BooleanExpression::And(box e1, box e2) => { + match (e1.propagate(constants), e2.propagate(constants)) { + (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => e, + (BooleanExpression::Value(false), _) | (_, BooleanExpression::Value(false)) => { + BooleanExpression::Value(false) + } + (e1, e2) => BooleanExpression::And(box e1, box e2), + } + } + BooleanExpression::Not(box e) => match e.propagate(constants) { + BooleanExpression::Value(v) => BooleanExpression::Value(!v), + e => BooleanExpression::Not(box e), + }, + BooleanExpression::IfElse(box condition, box consequence, box alternative) => { + let condition = condition.propagate(constants); + match condition { + BooleanExpression::Value(true) => consequence, + BooleanExpression::Value(false) => alternative, + _ => BooleanExpression::IfElse(box condition, box consequence, box alternative), + } + } + } + } +} + +impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator { + fn fold_function(&mut self, f: ZirFunction<'ast, T>) -> ZirFunction<'ast, T> { + let mut constants: HashMap, ZirExpression<'ast, T>> = HashMap::new(); + + ZirFunction { + signature: f.signature, + arguments: f.arguments, + statements: f + .statements + .into_iter() + .filter_map(|s| s.propagate(&mut constants)) + .collect(), + } + } +} diff --git a/zokrates_core/src/zir/mod.rs b/zokrates_core/src/zir/mod.rs index cf8ce0513..a3bbe705d 100644 --- a/zokrates_core/src/zir/mod.rs +++ b/zokrates_core/src/zir/mod.rs @@ -294,8 +294,8 @@ pub enum FieldElementExpression<'ast, T> { /// An expression of type `bool` #[derive(Clone, PartialEq, Hash, Eq, Debug)] pub enum BooleanExpression<'ast, T> { - Identifier(Identifier<'ast>), Value(bool), + Identifier(Identifier<'ast>), Select(Vec, Box>), FieldLt( Box>, @@ -313,19 +313,19 @@ pub enum BooleanExpression<'ast, T> { Box>, Box>, ), - UintLt(Box>, Box>), - UintLe(Box>, Box>), - UintGe(Box>, Box>), - UintGt(Box>, Box>), FieldEq( Box>, Box>, ), + UintLt(Box>, Box>), + UintLe(Box>, Box>), + UintGe(Box>, Box>), + UintGt(Box>, Box>), + UintEq(Box>, Box>), BoolEq( Box>, Box>, ), - UintEq(Box>, Box>), Or( Box>, Box>, diff --git a/zokrates_core/src/zir/uint.rs b/zokrates_core/src/zir/uint.rs index 04b043ba2..ac944ca2d 100644 --- a/zokrates_core/src/zir/uint.rs +++ b/zokrates_core/src/zir/uint.rs @@ -163,8 +163,8 @@ pub struct UExpression<'ast, T> { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum UExpressionInner<'ast, T> { - Identifier(Identifier<'ast>), Value(u128), + Identifier(Identifier<'ast>), Select(Vec>, Box>), Add(Box>, Box>), Sub(Box>, Box>), From e1f1cda27a12a6a9f1dc5ff82eb05a0da1459694 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 12 Aug 2021 16:15:27 +0200 Subject: [PATCH 02/78] use folder in zir propagation --- zokrates_core/src/flatten/mod.rs | 1 + .../src/static_analysis/uint_optimizer.rs | 119 ++-- .../src/static_analysis/zir_propagation.rs | 629 ++++++++++-------- 3 files changed, 406 insertions(+), 343 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 126190b8e..d9423b863 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1447,6 +1447,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // whether this value should be reduced, for example if it is then used in a bitwidth operation let should_reduce = metadata.should_reduce; + assert!(!should_reduce.is_unknown(), "should_reduce should be known"); let should_reduce = should_reduce.to_bool(); diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index ecca499a0..d651f2a47 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -456,77 +456,70 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ZirStatement::MultipleDefinition( lhs, ZirExpressionList::EmbedCall(embed, generics, arguments), - ) => match embed { - FlatEmbed::U64FromBits => { - assert_eq!(lhs.len(), 1); - self.register( - lhs[0].clone(), - UMetadata { - max: T::from(2).pow(64) - T::from(1), - should_reduce: ShouldReduce::False, - }, - ); - - vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - )] - } - FlatEmbed::U32FromBits => { - assert_eq!(lhs.len(), 1); - self.register( - lhs[0].clone(), - UMetadata { - max: T::from(2).pow(32) - T::from(1), - should_reduce: ShouldReduce::False, - }, - ); - - vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - )] - } - FlatEmbed::U16FromBits => { - assert_eq!(lhs.len(), 1); - self.register( - lhs[0].clone(), - UMetadata { - max: T::from(2).pow(16) - T::from(1), - should_reduce: ShouldReduce::False, - }, - ); - vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - )] - } - FlatEmbed::U8FromBits => { - assert_eq!(lhs.len(), 1); - self.register( - lhs[0].clone(), - UMetadata { - max: T::from(2).pow(8) - T::from(1), - should_reduce: ShouldReduce::False, - }, - ); - vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall(embed, generics, arguments), - )] - } - _ => vec![ZirStatement::MultipleDefinition( + ) => { + match embed { + FlatEmbed::U64FromBits => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(64) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + } + FlatEmbed::U32FromBits => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(32) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + } + FlatEmbed::U16FromBits => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(16) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + } + FlatEmbed::U8FromBits => { + assert_eq!(lhs.len(), 1); + self.register( + lhs[0].clone(), + UMetadata { + max: T::from(2).pow(8) - T::from(1), + should_reduce: ShouldReduce::False, + }, + ); + } + _ => {} + }; + + vec![ZirStatement::MultipleDefinition( lhs, ZirExpressionList::EmbedCall( embed, generics, arguments .into_iter() - .map(|e| self.fold_expression(e)) + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + let e = force_no_reduce(e); + ZirExpression::Uint(e) + } + e => self.fold_expression(e), + }) .collect(), ), - )], - }, + )] + } ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right)) => { let left = self.fold_uint_expression(left); let right = self.fold_uint_expression(right); diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 35e1b94be..4ee085ebe 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -1,252 +1,74 @@ +use crate::zir::folder::fold_statement; +use crate::zir::types::UBitwidth; use crate::zir::{ BooleanExpression, FieldElementExpression, Folder, UExpression, UExpressionInner, Variable, - ZirExpression, ZirFunction, ZirProgram, ZirStatement, + ZirExpression, ZirProgram, ZirStatement, }; use std::collections::HashMap; use zokrates_field::Field; type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; -trait Propagator<'ast, T> { - type Output; - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output; -} - #[derive(Default)] -pub struct ZirPropagator; +pub struct ZirPropagator<'ast, T> { + constants: Constants<'ast, T>, +} -impl ZirPropagator { +impl<'ast, T: Field> ZirPropagator<'ast, T> { pub fn new() -> Self { ZirPropagator::default() } - pub fn propagate(p: ZirProgram) -> ZirProgram { + pub fn propagate(p: ZirProgram) -> ZirProgram { ZirPropagator::new().fold_program(p) } } -impl<'ast, T: Field> Propagator<'ast, T> for ZirStatement<'ast, T> { - type Output = Option; - - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { - match self { - ZirStatement::Assertion(e) => match e.propagate(constants) { - BooleanExpression::Value(true) => None, - e => Some(ZirStatement::Assertion(e)), +impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { + fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { + match s { + ZirStatement::Assertion(e) => match self.fold_boolean_expression(e) { + BooleanExpression::Value(true) => vec![], + e => vec![ZirStatement::Assertion(e)], }, ZirStatement::Definition(a, e) => { - let e = e.propagate(constants); + let e = self.fold_expression(e); match e { - ZirExpression::FieldElement(FieldElementExpression::Number(_)) - | ZirExpression::Boolean(BooleanExpression::Value(_)) => { - constants.insert(a, e); - None - } - ZirExpression::Uint(e) => match e.inner { - UExpressionInner::Value(_) => { - constants.insert(a, ZirExpression::Uint(e)); - None - } - _ => Some(ZirStatement::Definition(a, ZirExpression::Uint(e))), - }, - _ => Some(ZirStatement::Definition(a, e)), - } - } - ZirStatement::IfElse(e, consequence, alternative) => Some(ZirStatement::IfElse( - e.propagate(constants), - consequence - .into_iter() - .filter_map(|s| s.propagate(constants)) - .collect(), - alternative - .into_iter() - .filter_map(|s| s.propagate(constants)) - .collect(), - )), - ZirStatement::Return(e) => Some(ZirStatement::Return( - e.into_iter().map(|e| e.propagate(constants)).collect(), - )), - ZirStatement::MultipleDefinition(assignees, list) => { - // TODO: apply propagation here - Some(ZirStatement::MultipleDefinition(assignees, list)) - } - } - } -} - -impl<'ast, T: Field> Propagator<'ast, T> for ZirExpression<'ast, T> { - type Output = Self; - - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { - match self { - ZirExpression::Boolean(e) => ZirExpression::Boolean(e.propagate(constants)), - ZirExpression::FieldElement(e) => ZirExpression::FieldElement(e.propagate(constants)), - ZirExpression::Uint(e) => ZirExpression::Uint(e.propagate(constants)), - } - } -} - -impl<'ast, T: Field> Propagator<'ast, T> for UExpression<'ast, T> { - type Output = Self; - - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { - UExpression { - inner: match self.inner { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => { - match constants.get(&Variable::uint(id.clone(), self.bitwidth)) { - Some(ZirExpression::Uint(e)) => match e.inner { - UExpressionInner::Value(v) => UExpressionInner::Value(v), - _ => unreachable!("should contain constant uint value"), - }, - _ => UExpressionInner::Identifier(id), - } - } - UExpressionInner::Select(e, box index) => { - let index = index.propagate(constants); - match index.inner { - UExpressionInner::Value(v) => e - .get(v as usize) - .cloned() - .expect("index out of bounds") - .into_inner(), - _ => UExpressionInner::Select(e, box index), - } - } - UExpressionInner::Add(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 + n2) - } - _ => UExpressionInner::Add(box e1, box e2), - } - } - UExpressionInner::Sub(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 - n2) - } - _ => UExpressionInner::Sub(box e1, box e2), + ZirExpression::FieldElement(FieldElementExpression::Number(..)) + | ZirExpression::Boolean(BooleanExpression::Value(..)) + | ZirExpression::Uint(UExpression { + inner: UExpressionInner::Value(..), + .. + }) => { + self.constants.insert(a, e); + vec![] } - } - UExpressionInner::Mult(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 * n2) - } - _ => UExpressionInner::Mult(box e1, box e2), - } - } - UExpressionInner::Div(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 / n2) - } - _ => UExpressionInner::Div(box e1, box e2), - } - } - UExpressionInner::Rem(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 % n2) - } - _ => UExpressionInner::Rem(box e1, box e2), - } - } - UExpressionInner::Xor(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 ^ n2) - } - _ => UExpressionInner::Xor(box e1, box e2), - } - } - UExpressionInner::And(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 & n2) - } - _ => UExpressionInner::And(box e1, box e2), - } - } - UExpressionInner::Or(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); - - match (&e1.inner, &e2.inner) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 | n2) - } - _ => UExpressionInner::Or(box e1, box e2), - } - } - UExpressionInner::LeftShift(box e, v) => { - let e = e.propagate(constants); - match &e.inner { - UExpressionInner::Value(n) => UExpressionInner::Value(n << v), - _ => UExpressionInner::LeftShift(box e, v), - } - } - UExpressionInner::RightShift(box e, v) => { - let e = e.propagate(constants); - match &e.inner { - UExpressionInner::Value(n) => UExpressionInner::Value(n >> v), - _ => UExpressionInner::RightShift(box e, v), - } - } - UExpressionInner::Not(box e) => { - let e = e.propagate(constants); - match &e.inner { - UExpressionInner::Value(n) => UExpressionInner::Value(!*n), - _ => UExpressionInner::Not(box e), + _ => { + self.constants.remove(&a); + vec![ZirStatement::Definition(a, e)] } } - UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let condition = condition.propagate(constants); - match condition { - BooleanExpression::Value(true) => consequence.into_inner(), - BooleanExpression::Value(false) => alternative.into_inner(), - _ => UExpressionInner::IfElse( - box condition, - box consequence, - box alternative, - ), - } + } + ZirStatement::MultipleDefinition(assignees, list) => { + for a in &assignees { + self.constants.remove(a); } - }, - ..self + vec![ZirStatement::MultipleDefinition( + assignees, + self.fold_expression_list(list), + )] + } + _ => fold_statement(self, s), } } -} - -impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { - type Output = Self; - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { - match self { + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + match e { FieldElementExpression::Number(n) => FieldElementExpression::Number(n), FieldElementExpression::Identifier(id) => { - match constants.get(&Variable::field_element(id.clone())) { + match self.constants.get(&Variable::field_element(id.clone())) { Some(ZirExpression::FieldElement(FieldElementExpression::Number(v))) => { FieldElementExpression::Number(v.clone()) } @@ -254,16 +76,30 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } FieldElementExpression::Select(e, box index) => { - let index = index.propagate(constants); - match index.inner { + let index = self.fold_uint_expression(index); + let e: Vec> = e + .into_iter() + .map(|e| self.fold_field_expression(e)) + .collect(); + + match index.into_inner() { UExpressionInner::Value(v) => { e.get(v as usize).cloned().expect("index out of bounds") } - _ => FieldElementExpression::Select(e, box index), + i => FieldElementExpression::Select(e, box i.annotate(UBitwidth::B32)), } } FieldElementExpression::Add(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { + (FieldElementExpression::Number(n), e) + | (e, FieldElementExpression::Number(n)) + if n == T::from(0) => + { + e + } (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 + n2) } @@ -271,7 +107,11 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } FieldElementExpression::Sub(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { + (e, FieldElementExpression::Number(n)) if n == T::from(0) => e, (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 - n2) } @@ -279,7 +119,22 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } FieldElementExpression::Mult(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { + (FieldElementExpression::Number(n), _) + | (_, FieldElementExpression::Number(n)) + if n == T::from(0) => + { + FieldElementExpression::Number(T::from(0)) + } + (FieldElementExpression::Number(n), e) + | (e, FieldElementExpression::Number(n)) + if n == T::from(1) => + { + e + } (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 * n2) } @@ -287,7 +142,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } FieldElementExpression::Div(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 / n2) } @@ -295,19 +153,25 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } FieldElementExpression::Pow(box e, box exponent) => { - let exponent = exponent.propagate(constants); - match (e.propagate(constants), &exponent.inner) { - (_, UExpressionInner::Value(n2)) if *n2 == 0 => { + let exponent = self.fold_uint_expression(exponent); + match (self.fold_field_expression(e), exponent.into_inner()) { + (_, UExpressionInner::Value(n2)) if n2 == 0 => { FieldElementExpression::Number(T::from(1)) } + (e, UExpressionInner::Value(n2)) if n2 == 1 => e, (FieldElementExpression::Number(n), UExpressionInner::Value(e)) => { - FieldElementExpression::Number(n.pow(*e as usize)) + FieldElementExpression::Number(n.pow(e as usize)) + } + (e, exp) => { + FieldElementExpression::Pow(box e, box exp.annotate(UBitwidth::B32)) } - (e, _) => FieldElementExpression::Pow(box e, box exponent), } } FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { - let condition = condition.propagate(constants); + let condition = self.fold_boolean_expression(condition); + let consequence = self.fold_field_expression(consequence); + let alternative = self.fold_field_expression(alternative); + match condition { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, @@ -320,16 +184,15 @@ impl<'ast, T: Field> Propagator<'ast, T> for FieldElementExpression<'ast, T> { } } } -} -impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { - type Output = Self; - - fn propagate(self, constants: &mut Constants<'ast, T>) -> Self::Output { - match self { + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => { - match constants.get(&Variable::boolean(id.clone())) { + match self.constants.get(&Variable::boolean(id.clone())) { Some(ZirExpression::Boolean(BooleanExpression::Value(v))) => { BooleanExpression::Value(*v) } @@ -337,16 +200,24 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::Select(e, box index) => { - let index = index.propagate(constants); - match index.inner { + let index = self.fold_uint_expression(index); + let e: Vec> = e + .into_iter() + .map(|e| self.fold_boolean_expression(e)) + .collect(); + + match index.as_inner() { UExpressionInner::Value(v) => { - e.get(v as usize).cloned().expect("index out of bounds") + e.get(*v as usize).cloned().expect("index out of bounds") } _ => BooleanExpression::Select(e, box index), } } BooleanExpression::FieldLt(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 < n2) } @@ -354,7 +225,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::FieldLe(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 <= n2) } @@ -362,7 +236,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::FieldGe(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 >= n2) } @@ -370,7 +247,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::FieldGt(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 > n2) } @@ -378,7 +258,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::FieldEq(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_field_expression(e1), + self.fold_field_expression(e2), + ) { (FieldElementExpression::Number(v1), FieldElementExpression::Number(v2)) => { BooleanExpression::Value(v1.eq(&v2)) } @@ -392,10 +275,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::UintLt(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - match (&e1.inner, &e2.inner) { + match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { BooleanExpression::Value(v1 < v2) } @@ -403,10 +286,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::UintLe(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - match (&e1.inner, &e2.inner) { + match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { BooleanExpression::Value(v1 <= v2) } @@ -414,10 +297,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::UintGe(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - match (&e1.inner, &e2.inner) { + match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { BooleanExpression::Value(v1 >= v2) } @@ -425,10 +308,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::UintGt(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - match (&e1.inner, &e2.inner) { + match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { BooleanExpression::Value(v1 > v2) } @@ -436,10 +319,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::UintEq(box e1, box e2) => { - let e1 = e1.propagate(constants); - let e2 = e2.propagate(constants); + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - match (&e1.inner, &e2.inner) { + match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { BooleanExpression::Value(v1 == v2) } @@ -453,7 +336,10 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::BoolEq(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_boolean_expression(e1), + self.fold_boolean_expression(e2), + ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { BooleanExpression::Value(v1 == v2) } @@ -467,15 +353,27 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } BooleanExpression::Or(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_boolean_expression(e1), + self.fold_boolean_expression(e2), + ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { BooleanExpression::Value(v1 || v2) } + (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { + BooleanExpression::Value(true) + } + (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { + e + } (e1, e2) => BooleanExpression::Or(box e1, box e2), } } BooleanExpression::And(box e1, box e2) => { - match (e1.propagate(constants), e2.propagate(constants)) { + match ( + self.fold_boolean_expression(e1), + self.fold_boolean_expression(e2), + ) { (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => e, (BooleanExpression::Value(false), _) | (_, BooleanExpression::Value(false)) => { BooleanExpression::Value(false) @@ -483,12 +381,15 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { (e1, e2) => BooleanExpression::And(box e1, box e2), } } - BooleanExpression::Not(box e) => match e.propagate(constants) { + BooleanExpression::Not(box e) => match self.fold_boolean_expression(e) { BooleanExpression::Value(v) => BooleanExpression::Value(!v), e => BooleanExpression::Not(box e), }, BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - let condition = condition.propagate(constants); + let condition = self.fold_boolean_expression(condition); + let consequence = self.fold_boolean_expression(consequence); + let alternative = self.fold_boolean_expression(alternative); + match condition { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, @@ -497,20 +398,188 @@ impl<'ast, T: Field> Propagator<'ast, T> for BooleanExpression<'ast, T> { } } } -} -impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator { - fn fold_function(&mut self, f: ZirFunction<'ast, T>) -> ZirFunction<'ast, T> { - let mut constants: HashMap, ZirExpression<'ast, T>> = HashMap::new(); + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + match e { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Identifier(id) => { + match self.constants.get(&Variable::uint(id.clone(), bitwidth)) { + Some(ZirExpression::Uint(e)) => e.as_inner().clone(), + _ => UExpressionInner::Identifier(id), + } + } + UExpressionInner::Select(e, box index) => { + let index = self.fold_uint_expression(index); + let e: Vec> = e + .into_iter() + .map(|e| self.fold_uint_expression(e)) + .collect(); + + match index.into_inner() { + UExpressionInner::Value(v) => e + .get(v as usize) + .cloned() + .expect("index out of bounds") + .into_inner(), + i => UExpressionInner::Select(e, box i.annotate(bitwidth)), + } + } + UExpressionInner::Add(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(0), e) | (e, UExpressionInner::Value(0)) => e, + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value((n1 + n2) % 2_u128.pow(bitwidth.to_usize() as u32)) + } + (e1, e2) => { + UExpressionInner::Add(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Sub(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (e, UExpressionInner::Value(0)) => e, + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value( + n1.wrapping_sub(n2) % 2_u128.pow(bitwidth.to_usize() as u32), + ) + } + (e1, e2) => { + UExpressionInner::Sub(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Mult(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { + UExpressionInner::Value(0) + } + (e, UExpressionInner::Value(1)) | (UExpressionInner::Value(1), e) => e, + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value((n1 * n2) % 2_u128.pow(bitwidth.to_usize() as u32)) + } + (e1, e2) => { + UExpressionInner::Mult(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Div(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)) + } + (e1, e2) => { + UExpressionInner::Div(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Rem(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value((n1 % n2) % 2_u128.pow(bitwidth.to_usize() as u32)) + } + (e1, e2) => { + UExpressionInner::Rem(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Xor(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 ^ n2) + } + (e1, e2) => { + UExpressionInner::Xor(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::And(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); - ZirFunction { - signature: f.signature, - arguments: f.arguments, - statements: f - .statements - .into_iter() - .filter_map(|s| s.propagate(&mut constants)) - .collect(), + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 & n2) + } + (e1, e2) => { + UExpressionInner::And(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::Or(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1); + let e2 = self.fold_uint_expression(e2); + + match (e1.into_inner(), e2.into_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + UExpressionInner::Value(n1 | n2) + } + (e1, e2) => { + UExpressionInner::Or(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + } + } + UExpressionInner::LeftShift(box e, v) => { + let e = self.fold_uint_expression(e); + match e.into_inner() { + UExpressionInner::Value(n) => { + UExpressionInner::Value((n << v) & (2_u128.pow(bitwidth as u32) - 1)) + } + e => UExpressionInner::LeftShift(box e.annotate(bitwidth), v), + } + } + UExpressionInner::RightShift(box e, v) => { + let e = self.fold_uint_expression(e); + match e.into_inner() { + UExpressionInner::Value(n) => UExpressionInner::Value(n >> v), + e => UExpressionInner::RightShift(box e.annotate(bitwidth), v), + } + } + UExpressionInner::Not(box e) => { + let e = self.fold_uint_expression(e); + match e.into_inner() { + UExpressionInner::Value(n) => { + UExpressionInner::Value(!n & (2_u128.pow(bitwidth as u32) - 1)) + } + e => UExpressionInner::Not(box e.annotate(bitwidth)), + } + } + UExpressionInner::IfElse(box condition, box consequence, box alternative) => { + let condition = self.fold_boolean_expression(condition); + let consequence = self.fold_uint_expression(consequence).into_inner(); + let alternative = self.fold_uint_expression(alternative).into_inner(); + + match condition { + BooleanExpression::Value(true) => consequence, + BooleanExpression::Value(false) => alternative, + _ => UExpressionInner::IfElse( + box condition, + box consequence.annotate(bitwidth), + box alternative.annotate(bitwidth), + ), + } + } } } } From ca0fc111f295a52e30890fcb1b9a04584e816b60 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 12 Aug 2021 20:26:53 +0200 Subject: [PATCH 03/78] write tests for zir propagator --- .../src/static_analysis/zir_propagation.rs | 879 +++++++++++++++++- 1 file changed, 854 insertions(+), 25 deletions(-) diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 4ee085ebe..1a157181c 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -1,13 +1,13 @@ use crate::zir::folder::fold_statement; use crate::zir::types::UBitwidth; use crate::zir::{ - BooleanExpression, FieldElementExpression, Folder, UExpression, UExpressionInner, Variable, + BooleanExpression, FieldElementExpression, Folder, Identifier, UExpression, UExpressionInner, ZirExpression, ZirProgram, ZirStatement, }; use std::collections::HashMap; use zokrates_field::Field; -type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; +type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; #[derive(Default)] pub struct ZirPropagator<'ast, T> { @@ -39,18 +39,18 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { inner: UExpressionInner::Value(..), .. }) => { - self.constants.insert(a, e); + self.constants.insert(a.id, e); vec![] } _ => { - self.constants.remove(&a); + self.constants.remove(&a.id); vec![ZirStatement::Definition(a, e)] } } } ZirStatement::MultipleDefinition(assignees, list) => { for a in &assignees { - self.constants.remove(a); + self.constants.remove(&a.id); } vec![ZirStatement::MultipleDefinition( assignees, @@ -67,14 +67,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { ) -> FieldElementExpression<'ast, T> { match e { FieldElementExpression::Number(n) => FieldElementExpression::Number(n), - FieldElementExpression::Identifier(id) => { - match self.constants.get(&Variable::field_element(id.clone())) { - Some(ZirExpression::FieldElement(FieldElementExpression::Number(v))) => { - FieldElementExpression::Number(v.clone()) - } - _ => FieldElementExpression::Identifier(id), + FieldElementExpression::Identifier(id) => match self.constants.get(&id) { + Some(ZirExpression::FieldElement(FieldElementExpression::Number(v))) => { + FieldElementExpression::Number(v.clone()) } - } + _ => FieldElementExpression::Identifier(id), + }, FieldElementExpression::Select(e, box index) => { let index = self.fold_uint_expression(index); let e: Vec> = e @@ -191,14 +189,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { ) -> BooleanExpression<'ast, T> { match e { BooleanExpression::Value(v) => BooleanExpression::Value(v), - BooleanExpression::Identifier(id) => { - match self.constants.get(&Variable::boolean(id.clone())) { - Some(ZirExpression::Boolean(BooleanExpression::Value(v))) => { - BooleanExpression::Value(*v) - } - _ => BooleanExpression::Identifier(id), + BooleanExpression::Identifier(id) => match self.constants.get(&id) { + Some(ZirExpression::Boolean(BooleanExpression::Value(v))) => { + BooleanExpression::Value(*v) } - } + _ => BooleanExpression::Identifier(id), + }, BooleanExpression::Select(e, box index) => { let index = self.fold_uint_expression(index); let e: Vec> = e @@ -406,12 +402,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { ) -> UExpressionInner<'ast, T> { match e { UExpressionInner::Value(v) => UExpressionInner::Value(v), - UExpressionInner::Identifier(id) => { - match self.constants.get(&Variable::uint(id.clone(), bitwidth)) { - Some(ZirExpression::Uint(e)) => e.as_inner().clone(), - _ => UExpressionInner::Identifier(id), - } - } + UExpressionInner::Identifier(id) => match self.constants.get(&id) { + Some(ZirExpression::Uint(e)) => e.as_inner().clone(), + _ => UExpressionInner::Identifier(id), + }, UExpressionInner::Select(e, box index) => { let index = self.fold_uint_expression(index); let e: Vec> = e @@ -583,3 +577,838 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn propagation() { + // assert([x, 1] == [y, 1]) + let statements = vec![ZirStatement::Assertion(BooleanExpression::And( + box BooleanExpression::FieldEq( + box FieldElementExpression::Identifier("x".into()), + box FieldElementExpression::Identifier("y".into()), + ), + box BooleanExpression::FieldEq( + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(1)), + ), + ))]; + + let mut propagator = ZirPropagator::new(); + let statements: Vec> = statements + .into_iter() + .flat_map(|s| propagator.fold_statement(s)) + .collect(); + + assert_eq!( + statements, + vec![ZirStatement::Assertion(BooleanExpression::FieldEq( + box FieldElementExpression::Identifier("x".into()), + box FieldElementExpression::Identifier("y".into()), + ))] + ); + } + + #[cfg(test)] + mod field { + use super::*; + + #[test] + fn select() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Select( + vec![ + FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::Number(Bn128Field::from(2)), + ], + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Number(Bn128Field::from(2)) + ); + } + + #[test] + fn add() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Add( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + FieldElementExpression::Number(Bn128Field::from(5)) + ); + + // a + 0 = a + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Add( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(0)), + )), + FieldElementExpression::Identifier("a".into()) + ); + } + + #[test] + fn sub() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Sub( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(1)) + ); + + // a - 0 = a + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Sub( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(0)), + )), + FieldElementExpression::Identifier("a".into()) + ); + } + + #[test] + fn mult() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Mult( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(6)) + ); + + // a * 0 = 0 + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Mult( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(0)), + )), + FieldElementExpression::Number(Bn128Field::from(0)) + ); + + // a * 1 = a + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Mult( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(1)), + )), + FieldElementExpression::Identifier("a".into()) + ); + } + + #[test] + fn div() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Div( + box FieldElementExpression::Number(Bn128Field::from(6)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(3)) + ); + } + + #[test] + fn pow() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Pow( + box FieldElementExpression::Number(Bn128Field::from(3)), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + FieldElementExpression::Number(Bn128Field::from(9)) + ); + + // a ** 0 = 1 + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Pow( + box FieldElementExpression::Identifier("a".into()), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + FieldElementExpression::Number(Bn128Field::from(1)) + ); + + // a ** 1 = a + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Pow( + box FieldElementExpression::Identifier("a".into()), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + FieldElementExpression::Identifier("a".into()) + ); + } + + #[test] + fn if_else() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::IfElse( + box BooleanExpression::Value(true), + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(1)) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::IfElse( + box BooleanExpression::Value(false), + box FieldElementExpression::Number(Bn128Field::from(1)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(2)) + ); + } + } + + #[cfg(test)] + mod bool { + use super::*; + + #[test] + fn select() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Select( + vec![ + BooleanExpression::Value(false), + BooleanExpression::Value(true), + ], + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn field_lt() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldLt( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldLt( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn field_le() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldLe( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldLe( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn field_ge() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldGe( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldGe( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn field_gt() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldGt( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldGt( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(3)), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn field_eq() { + let mut propagator = ZirPropagator::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldEq( + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::FieldEq( + box FieldElementExpression::Number(Bn128Field::from(3)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn uint_lt() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintLt( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintLt( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn uint_le() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintLe( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintLe( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn uint_ge() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintGe( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintGe( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn uint_gt() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintGt( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintGt( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn uint_eq() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintEq( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::UintEq( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn bool_eq() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::BoolEq( + box BooleanExpression::Value(true), + box BooleanExpression::Value(true), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::BoolEq( + box BooleanExpression::Value(true), + box BooleanExpression::Value(false), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn and() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::And( + box BooleanExpression::Value(true), + box BooleanExpression::Value(true), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::And( + box BooleanExpression::Value(true), + box BooleanExpression::Value(false), + )), + BooleanExpression::Value(false) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::And( + box BooleanExpression::Identifier("a".into()), + box BooleanExpression::Value(true), + )), + BooleanExpression::Identifier("a".into()) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::And( + box BooleanExpression::Identifier("a".into()), + box BooleanExpression::Value(false), + )), + BooleanExpression::Value(false) + ); + } + + #[test] + fn or() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(true), + box BooleanExpression::Value(true), + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(true), + box BooleanExpression::Value(false), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn not() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Not( + box BooleanExpression::Value(true), + )), + BooleanExpression::Value(false) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Not( + box BooleanExpression::Value(false), + )), + BooleanExpression::Value(true) + ); + } + + #[test] + fn if_else() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::IfElse( + box BooleanExpression::Value(true), + box BooleanExpression::Value(true), + box BooleanExpression::Value(false) + )), + BooleanExpression::Value(true) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::IfElse( + box BooleanExpression::Value(false), + box BooleanExpression::Value(true), + box BooleanExpression::Value(false) + )), + BooleanExpression::Value(false) + ); + } + } + + #[cfg(test)] + mod uint { + use super::*; + + #[test] + fn select() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Select( + vec![ + UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpressionInner::Value(2).annotate(UBitwidth::B32), + ], + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(2) + ); + } + + #[test] + fn add() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Add( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(5) + ); + + // a + 0 = a + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Add( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); + } + + #[test] + fn sub() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Sub( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(1) + ); + + // a - 0 = a + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Sub( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); + } + + #[test] + fn mult() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Mult( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(6) + ); + + // a * 1 = a + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Mult( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); + + // a * 0 = 0 + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Mult( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(0) + ); + } + + #[test] + fn div() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Div( + box UExpressionInner::Value(6).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(3) + ); + } + + #[test] + fn rem() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Rem( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(2) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Rem( + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(1) + ); + } + + #[test] + fn xor() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Xor( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(1) + ); + } + + #[test] + fn and() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::And( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(2) + ); + } + + #[test] + fn or() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Or( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(3) + ); + } + + #[test] + fn left_shift() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::LeftShift( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + 3, + ) + ), + UExpressionInner::Value(16) + ); + } + + #[test] + fn right_shift() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::RightShift( + box UExpressionInner::Value(4).annotate(UBitwidth::B32), + 2, + ) + ), + UExpressionInner::Value(1) + ); + } + + #[test] + fn not() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Not(box UExpressionInner::Value(2).annotate(UBitwidth::B32),) + ), + UExpressionInner::Value(4294967293) + ); + } + + #[test] + fn if_else() { + let mut propagator = ZirPropagator::::new(); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::IfElse( + box BooleanExpression::Value(true), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(1) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::IfElse( + box BooleanExpression::Value(false), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(2) + ); + } + } +} From bf948dd3b6596f98d875225349f115d1bb5771e5 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 16 Aug 2021 20:48:42 +0200 Subject: [PATCH 04/78] add more optimizations --- zokrates_core/src/flatten/mod.rs | 1 - .../src/static_analysis/zir_propagation.rs | 276 ++++++++++++++---- 2 files changed, 226 insertions(+), 51 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index d9423b863..126190b8e 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1447,7 +1447,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { // whether this value should be reduced, for example if it is then used in a bitwidth operation let should_reduce = metadata.should_reduce; - assert!(!should_reduce.is_unknown(), "should_reduce should be known"); let should_reduce = should_reduce.to_bool(); diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 1a157181c..4d82ef899 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -15,11 +15,8 @@ pub struct ZirPropagator<'ast, T> { } impl<'ast, T: Field> ZirPropagator<'ast, T> { - pub fn new() -> Self { - ZirPropagator::default() - } pub fn propagate(p: ZirProgram) -> ZirProgram { - ZirPropagator::new().fold_program(p) + ZirPropagator::default().fold_program(p) } } @@ -144,6 +141,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { self.fold_field_expression(e1), self.fold_field_expression(e2), ) { + (e, FieldElementExpression::Number(n)) if n == T::from(1) => e, (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 / n2) } @@ -170,6 +168,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_field_expression(consequence); let alternative = self.fold_field_expression(alternative); + if consequence.eq(&alternative) { + return consequence; + } + match condition { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, @@ -386,6 +388,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_boolean_expression(consequence); let alternative = self.fold_boolean_expression(alternative); + if consequence.eq(&alternative) { + return consequence; + } + match condition { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, @@ -474,6 +480,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2); match (e1.into_inner(), e2.into_inner()) { + (e, UExpressionInner::Value(n)) if n == 1 => e, (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)) } @@ -503,6 +510,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { UExpressionInner::Value(n1 ^ n2) } + (e1, e2) if e1.eq(&e2) => UExpressionInner::Value(0), (e1, e2) => { UExpressionInner::Xor(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) } @@ -513,6 +521,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2); match (e1.into_inner(), e2.into_inner()) { + (e, UExpressionInner::Value(n)) + if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + { + e + } + (_, UExpressionInner::Value(0)) => UExpressionInner::Value(0), (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { UExpressionInner::Value(n1 & n2) } @@ -526,6 +540,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2); match (e1.into_inner(), e2.into_inner()) { + (e, UExpressionInner::Value(0)) => e, + (_, UExpressionInner::Value(n)) + if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => + { + UExpressionInner::Value(n) + } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { UExpressionInner::Value(n1 | n2) } @@ -534,20 +554,24 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { } } } - UExpressionInner::LeftShift(box e, v) => { + UExpressionInner::LeftShift(box e, by) => { let e = self.fold_uint_expression(e); - match e.into_inner() { - UExpressionInner::Value(n) => { - UExpressionInner::Value((n << v) & (2_u128.pow(bitwidth as u32) - 1)) + match (e.into_inner(), by) { + (e, 0) => e, + (_, by) if by >= bitwidth as u32 => UExpressionInner::Value(0), + (UExpressionInner::Value(n), by) => { + UExpressionInner::Value((n << by) & (2_u128.pow(bitwidth as u32) - 1)) } - e => UExpressionInner::LeftShift(box e.annotate(bitwidth), v), + (e, by) => UExpressionInner::LeftShift(box e.annotate(bitwidth), by), } } - UExpressionInner::RightShift(box e, v) => { + UExpressionInner::RightShift(box e, by) => { let e = self.fold_uint_expression(e); - match e.into_inner() { - UExpressionInner::Value(n) => UExpressionInner::Value(n >> v), - e => UExpressionInner::RightShift(box e.annotate(bitwidth), v), + match (e.into_inner(), by) { + (e, 0) => e, + (_, by) if by >= bitwidth as u32 => UExpressionInner::Value(0), + (UExpressionInner::Value(n), by) => UExpressionInner::Value(n >> by), + (e, by) => UExpressionInner::RightShift(box e.annotate(bitwidth), by), } } UExpressionInner::Not(box e) => { @@ -564,6 +588,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_uint_expression(consequence).into_inner(); let alternative = self.fold_uint_expression(alternative).into_inner(); + if consequence.eq(&alternative) { + return consequence; + } + match condition { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, @@ -597,7 +625,7 @@ mod tests { ), ))]; - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); let statements: Vec> = statements .into_iter() .flat_map(|s| propagator.fold_statement(s)) @@ -618,7 +646,7 @@ mod tests { #[test] fn select() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Select( @@ -634,7 +662,7 @@ mod tests { #[test] fn add() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Add( @@ -656,7 +684,7 @@ mod tests { #[test] fn sub() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Sub( @@ -678,7 +706,7 @@ mod tests { #[test] fn mult() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Mult( @@ -709,7 +737,7 @@ mod tests { #[test] fn div() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Div( @@ -718,11 +746,19 @@ mod tests { )), FieldElementExpression::Number(Bn128Field::from(3)) ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Div( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(1)), + )), + FieldElementExpression::Identifier("a".into()) + ); } #[test] fn pow() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::Pow( @@ -753,7 +789,7 @@ mod tests { #[test] fn if_else() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::IfElse( @@ -772,6 +808,15 @@ mod tests { )), FieldElementExpression::Number(Bn128Field::from(2)) ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::IfElse( + box BooleanExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(2)), + box FieldElementExpression::Number(Bn128Field::from(2)), + )), + FieldElementExpression::Number(Bn128Field::from(2)) + ); } } @@ -781,7 +826,7 @@ mod tests { #[test] fn select() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::Select( @@ -797,7 +842,7 @@ mod tests { #[test] fn field_lt() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::FieldLt( @@ -818,7 +863,7 @@ mod tests { #[test] fn field_le() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::FieldLe( @@ -839,7 +884,7 @@ mod tests { #[test] fn field_ge() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::FieldGe( @@ -860,7 +905,7 @@ mod tests { #[test] fn field_gt() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::FieldGt( @@ -881,7 +926,7 @@ mod tests { #[test] fn field_eq() { - let mut propagator = ZirPropagator::new(); + let mut propagator = ZirPropagator::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::FieldEq( @@ -902,7 +947,7 @@ mod tests { #[test] fn uint_lt() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::UintLt( @@ -923,7 +968,7 @@ mod tests { #[test] fn uint_le() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::UintLe( @@ -944,7 +989,7 @@ mod tests { #[test] fn uint_ge() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::UintGe( @@ -965,7 +1010,7 @@ mod tests { #[test] fn uint_gt() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::UintGt( @@ -986,7 +1031,7 @@ mod tests { #[test] fn uint_eq() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::UintEq( @@ -1007,7 +1052,7 @@ mod tests { #[test] fn bool_eq() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::BoolEq( @@ -1028,7 +1073,7 @@ mod tests { #[test] fn and() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::And( @@ -1065,7 +1110,7 @@ mod tests { #[test] fn or() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::Or( @@ -1086,7 +1131,7 @@ mod tests { #[test] fn not() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::Not( @@ -1105,7 +1150,7 @@ mod tests { #[test] fn if_else() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::IfElse( @@ -1124,6 +1169,15 @@ mod tests { )), BooleanExpression::Value(false) ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::IfElse( + box BooleanExpression::Identifier("a".into()), + box BooleanExpression::Value(true), + box BooleanExpression::Value(true) + )), + BooleanExpression::Value(true) + ); } } @@ -1133,7 +1187,7 @@ mod tests { #[test] fn select() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1152,7 +1206,7 @@ mod tests { #[test] fn add() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1180,7 +1234,7 @@ mod tests { #[test] fn sub() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1208,7 +1262,7 @@ mod tests { #[test] fn mult() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1248,7 +1302,7 @@ mod tests { #[test] fn div() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1260,11 +1314,22 @@ mod tests { ), UExpressionInner::Value(3) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Div( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(1).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); } #[test] fn rem() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1291,7 +1356,7 @@ mod tests { #[test] fn xor() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1303,11 +1368,22 @@ mod tests { ), UExpressionInner::Value(1) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Xor( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(0) + ); } #[test] fn and() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1319,11 +1395,33 @@ mod tests { ), UExpressionInner::Value(2) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::And( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(0) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::And( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); } #[test] fn or() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1335,11 +1433,33 @@ mod tests { ), UExpressionInner::Value(3) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Or( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Identifier("a".into()) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Or( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(u32::MAX as u128) + ); } #[test] fn left_shift() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1351,11 +1471,33 @@ mod tests { ), UExpressionInner::Value(16) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::LeftShift( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + 0, + ) + ), + UExpressionInner::Value(2) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::LeftShift( + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + 32, + ) + ), + UExpressionInner::Value(0) + ); } #[test] fn right_shift() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1367,11 +1509,33 @@ mod tests { ), UExpressionInner::Value(1) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::RightShift( + box UExpressionInner::Value(4).annotate(UBitwidth::B32), + 0, + ) + ), + UExpressionInner::Value(4) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::RightShift( + box UExpressionInner::Value(4).annotate(UBitwidth::B32), + 32, + ) + ), + UExpressionInner::Value(0) + ); } #[test] fn not() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1384,7 +1548,7 @@ mod tests { #[test] fn if_else() { - let mut propagator = ZirPropagator::::new(); + let mut propagator = ZirPropagator::::default(); assert_eq!( propagator.fold_uint_expression_inner( @@ -1409,6 +1573,18 @@ mod tests { ), UExpressionInner::Value(2) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::IfElse( + box BooleanExpression::Identifier("a".into()), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + box UExpressionInner::Value(2).annotate(UBitwidth::B32), + ) + ), + UExpressionInner::Value(2) + ); } } } From 81751c2ca4f5ea0030eda227d21a2c4de3e0111c Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 17 Aug 2021 17:40:35 +0200 Subject: [PATCH 05/78] add symmetric cases --- .../src/static_analysis/zir_propagation.rs | 51 +++++++++---------- zokrates_core/src/zir/result_folder.rs | 0 2 files changed, 23 insertions(+), 28 deletions(-) create mode 100644 zokrates_core/src/zir/result_folder.rs diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 4d82ef899..667669873 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -168,14 +168,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_field_expression(consequence); let alternative = self.fold_field_expression(alternative); - if consequence.eq(&alternative) { - return consequence; - } - - match condition { - BooleanExpression::Value(true) => consequence, - BooleanExpression::Value(false) => alternative, - _ => FieldElementExpression::IfElse( + match (condition, consequence, alternative) { + (_, consequence, alternative) if consequence == alternative => consequence, + (BooleanExpression::Value(true), consequence, _) => consequence, + (BooleanExpression::Value(false), _, alternative) => alternative, + (condition, consequence, alternative) => FieldElementExpression::IfElse( box condition, box consequence, box alternative, @@ -388,14 +385,13 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_boolean_expression(consequence); let alternative = self.fold_boolean_expression(alternative); - if consequence.eq(&alternative) { - return consequence; - } - - match condition { - BooleanExpression::Value(true) => consequence, - BooleanExpression::Value(false) => alternative, - _ => BooleanExpression::IfElse(box condition, box consequence, box alternative), + match (condition, consequence, alternative) { + (_, consequence, alternative) if consequence == alternative => consequence, + (BooleanExpression::Value(true), consequence, _) => consequence, + (BooleanExpression::Value(false), _, alternative) => alternative, + (condition, consequence, alternative) => { + BooleanExpression::IfElse(box condition, box consequence, box alternative) + } } } } @@ -521,12 +517,14 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2); match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(n)) + (e, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), e) if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { e } - (_, UExpressionInner::Value(0)) => UExpressionInner::Value(0), + (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { + UExpressionInner::Value(0) + } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { UExpressionInner::Value(n1 & n2) } @@ -540,8 +538,8 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2); match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) => e, - (_, UExpressionInner::Value(n)) + (e, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), e) => e, + (_, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), _) if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { UExpressionInner::Value(n) @@ -588,14 +586,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { let consequence = self.fold_uint_expression(consequence).into_inner(); let alternative = self.fold_uint_expression(alternative).into_inner(); - if consequence.eq(&alternative) { - return consequence; - } - - match condition { - BooleanExpression::Value(true) => consequence, - BooleanExpression::Value(false) => alternative, - _ => UExpressionInner::IfElse( + match (condition, consequence, alternative) { + (_, consequence, alternative) if consequence == alternative => consequence, + (BooleanExpression::Value(true), consequence, _) => consequence, + (BooleanExpression::Value(false), _, alternative) => alternative, + (condition, consequence, alternative) => UExpressionInner::IfElse( box condition, box consequence.annotate(bitwidth), box alternative.annotate(bitwidth), diff --git a/zokrates_core/src/zir/result_folder.rs b/zokrates_core/src/zir/result_folder.rs new file mode 100644 index 000000000..e69de29bb From 934d36078e6acb0f18aca85d97f58616ce74e141 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 17 Aug 2021 18:31:52 +0200 Subject: [PATCH 06/78] use result folder in zir propagation --- .../examples/compile_errors/out_of_bounds.zok | 4 + zokrates_core/src/static_analysis/mod.rs | 10 +- .../src/static_analysis/zir_propagation.rs | 728 ++++++++++-------- zokrates_core/src/zir/mod.rs | 1 + zokrates_core/src/zir/result_folder.rs | 418 ++++++++++ 5 files changed, 837 insertions(+), 324 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/out_of_bounds.zok diff --git a/zokrates_cli/examples/compile_errors/out_of_bounds.zok b/zokrates_cli/examples/compile_errors/out_of_bounds.zok new file mode 100644 index 000000000..f2922a4bc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/out_of_bounds.zok @@ -0,0 +1,4 @@ +def main() -> field: + field[10] a = [0; 10] + u32 index = if [0f] != [1f] then 1000 else 0 fi + return a[index] \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 0ab0c296b..cc3ad4f2c 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -41,6 +41,7 @@ pub trait Analyse { pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), + ZirPropagation(self::zir_propagation::Error), NonConstantShift(self::shift_checker::Error), } @@ -56,6 +57,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: zir_propagation::Error) -> Self { + Error::ZirPropagation(e) + } +} + impl From for Error { fn from(e: shift_checker::Error) -> Self { Error::NonConstantShift(e) @@ -67,6 +74,7 @@ impl fmt::Display for Error { match self { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), + Error::ZirPropagation(e) => write!(f, "{}", e), Error::NonConstantShift(e) => write!(f, "{}", e), } } @@ -121,7 +129,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { // apply propagation in zir log::debug!("Static analyser: Apply propagation in zir"); - let zir = ZirPropagator::propagate(zir); + let zir = ZirPropagator::propagate(zir).map_err(Error::from)?; log::trace!("\n{}", zir); // optimize uint expressions diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 667669873..220e6916f 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -1,34 +1,58 @@ -use crate::zir::folder::fold_statement; +use crate::zir::result_folder::fold_statement; +use crate::zir::result_folder::ResultFolder; use crate::zir::types::UBitwidth; use crate::zir::{ - BooleanExpression, FieldElementExpression, Folder, Identifier, UExpression, UExpressionInner, + BooleanExpression, FieldElementExpression, Identifier, UExpression, UExpressionInner, ZirExpression, ZirProgram, ZirStatement, }; use std::collections::HashMap; +use std::fmt; use zokrates_field::Field; type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; +#[derive(Debug, PartialEq)] +pub enum Error { + OutOfBounds(u128, u128), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::OutOfBounds(index, size) => write!( + f, + "Out of bounds index ({} >= {}) found in zir during static analysis", + index, size + ), + } + } +} + #[derive(Default)] pub struct ZirPropagator<'ast, T> { constants: Constants<'ast, T>, } impl<'ast, T: Field> ZirPropagator<'ast, T> { - pub fn propagate(p: ZirProgram) -> ZirProgram { + pub fn propagate(p: ZirProgram) -> Result, Error> { ZirPropagator::default().fold_program(p) } } -impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { - fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec> { +impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { + type Error = Error; + + fn fold_statement( + &mut self, + s: ZirStatement<'ast, T>, + ) -> Result>, Self::Error> { match s { - ZirStatement::Assertion(e) => match self.fold_boolean_expression(e) { - BooleanExpression::Value(true) => vec![], - e => vec![ZirStatement::Assertion(e)], + ZirStatement::Assertion(e) => match self.fold_boolean_expression(e)? { + BooleanExpression::Value(true) => Ok(vec![]), + e => Ok(vec![ZirStatement::Assertion(e)]), }, ZirStatement::Definition(a, e) => { - let e = self.fold_expression(e); + let e = self.fold_expression(e)?; match e { ZirExpression::FieldElement(FieldElementExpression::Number(..)) | ZirExpression::Boolean(BooleanExpression::Value(..)) @@ -37,11 +61,11 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { .. }) => { self.constants.insert(a.id, e); - vec![] + Ok(vec![]) } _ => { self.constants.remove(&a.id); - vec![ZirStatement::Definition(a, e)] + Ok(vec![ZirStatement::Definition(a, e)]) } } } @@ -49,10 +73,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { for a in &assignees { self.constants.remove(&a.id); } - vec![ZirStatement::MultipleDefinition( + Ok(vec![ZirStatement::MultipleDefinition( assignees, - self.fold_expression_list(list), - )] + self.fold_expression_list(list)?, + )]) } _ => fold_statement(self, s), } @@ -61,122 +85,127 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { fn fold_field_expression( &mut self, e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { + ) -> Result, Self::Error> { match e { - FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Number(n) => Ok(FieldElementExpression::Number(n)), FieldElementExpression::Identifier(id) => match self.constants.get(&id) { Some(ZirExpression::FieldElement(FieldElementExpression::Number(v))) => { - FieldElementExpression::Number(v.clone()) + Ok(FieldElementExpression::Number(v.clone())) } - _ => FieldElementExpression::Identifier(id), + _ => Ok(FieldElementExpression::Identifier(id)), }, FieldElementExpression::Select(e, box index) => { - let index = self.fold_uint_expression(index); + let index = self.fold_uint_expression(index)?; let e: Vec> = e .into_iter() .map(|e| self.fold_field_expression(e)) - .collect(); + .collect::>()?; match index.into_inner() { - UExpressionInner::Value(v) => { - e.get(v as usize).cloned().expect("index out of bounds") - } - i => FieldElementExpression::Select(e, box i.annotate(UBitwidth::B32)), + UExpressionInner::Value(v) => e + .get(v as usize) + .cloned() + .ok_or(Error::OutOfBounds(v, e.len() as u128)), + i => Ok(FieldElementExpression::Select( + e, + box i.annotate(UBitwidth::B32), + )), } } FieldElementExpression::Add(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n), e) | (e, FieldElementExpression::Number(n)) if n == T::from(0) => { - e + Ok(e) } (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - FieldElementExpression::Number(n1 + n2) + Ok(FieldElementExpression::Number(n1 + n2)) } - (e1, e2) => FieldElementExpression::Add(box e1, box e2), + (e1, e2) => Ok(FieldElementExpression::Add(box e1, box e2)), } } FieldElementExpression::Sub(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { - (e, FieldElementExpression::Number(n)) if n == T::from(0) => e, + (e, FieldElementExpression::Number(n)) if n == T::from(0) => Ok(e), (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - FieldElementExpression::Number(n1 - n2) + Ok(FieldElementExpression::Number(n1 - n2)) } - (e1, e2) => FieldElementExpression::Sub(box e1, box e2), + (e1, e2) => Ok(FieldElementExpression::Sub(box e1, box e2)), } } FieldElementExpression::Mult(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n), _) | (_, FieldElementExpression::Number(n)) if n == T::from(0) => { - FieldElementExpression::Number(T::from(0)) + Ok(FieldElementExpression::Number(T::from(0))) } (FieldElementExpression::Number(n), e) | (e, FieldElementExpression::Number(n)) if n == T::from(1) => { - e + Ok(e) } (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - FieldElementExpression::Number(n1 * n2) + Ok(FieldElementExpression::Number(n1 * n2)) } - (e1, e2) => FieldElementExpression::Mult(box e1, box e2), + (e1, e2) => Ok(FieldElementExpression::Mult(box e1, box e2)), } } FieldElementExpression::Div(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { - (e, FieldElementExpression::Number(n)) if n == T::from(1) => e, + (e, FieldElementExpression::Number(n)) if n == T::from(1) => Ok(e), (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - FieldElementExpression::Number(n1 / n2) + Ok(FieldElementExpression::Number(n1 / n2)) } - (e1, e2) => FieldElementExpression::Div(box e1, box e2), + (e1, e2) => Ok(FieldElementExpression::Div(box e1, box e2)), } } FieldElementExpression::Pow(box e, box exponent) => { - let exponent = self.fold_uint_expression(exponent); - match (self.fold_field_expression(e), exponent.into_inner()) { + let exponent = self.fold_uint_expression(exponent)?; + match (self.fold_field_expression(e)?, exponent.into_inner()) { (_, UExpressionInner::Value(n2)) if n2 == 0 => { - FieldElementExpression::Number(T::from(1)) + Ok(FieldElementExpression::Number(T::from(1))) } - (e, UExpressionInner::Value(n2)) if n2 == 1 => e, + (e, UExpressionInner::Value(n2)) if n2 == 1 => Ok(e), (FieldElementExpression::Number(n), UExpressionInner::Value(e)) => { - FieldElementExpression::Number(n.pow(e as usize)) - } - (e, exp) => { - FieldElementExpression::Pow(box e, box exp.annotate(UBitwidth::B32)) + Ok(FieldElementExpression::Number(n.pow(e as usize))) } + (e, exp) => Ok(FieldElementExpression::Pow( + box e, + box exp.annotate(UBitwidth::B32), + )), } } FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { - let condition = self.fold_boolean_expression(condition); - let consequence = self.fold_field_expression(consequence); - let alternative = self.fold_field_expression(alternative); + let condition = self.fold_boolean_expression(condition)?; + let consequence = self.fold_field_expression(consequence)?; + let alternative = self.fold_field_expression(alternative)?; match (condition, consequence, alternative) { - (_, consequence, alternative) if consequence == alternative => consequence, - (BooleanExpression::Value(true), consequence, _) => consequence, - (BooleanExpression::Value(false), _, alternative) => alternative, - (condition, consequence, alternative) => FieldElementExpression::IfElse( + (_, consequence, alternative) if consequence == alternative => Ok(consequence), + (BooleanExpression::Value(true), consequence, _) => Ok(consequence), + (BooleanExpression::Value(false), _, alternative) => Ok(alternative), + (condition, consequence, alternative) => Ok(FieldElementExpression::IfElse( box condition, box consequence, box alternative, - ), + )), } } } @@ -185,213 +214,218 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { + ) -> Result, Error> { match e { - BooleanExpression::Value(v) => BooleanExpression::Value(v), + BooleanExpression::Value(v) => Ok(BooleanExpression::Value(v)), BooleanExpression::Identifier(id) => match self.constants.get(&id) { Some(ZirExpression::Boolean(BooleanExpression::Value(v))) => { - BooleanExpression::Value(*v) + Ok(BooleanExpression::Value(*v)) } - _ => BooleanExpression::Identifier(id), + _ => Ok(BooleanExpression::Identifier(id)), }, BooleanExpression::Select(e, box index) => { - let index = self.fold_uint_expression(index); + let index = self.fold_uint_expression(index)?; let e: Vec> = e .into_iter() .map(|e| self.fold_boolean_expression(e)) - .collect(); + .collect::>()?; match index.as_inner() { - UExpressionInner::Value(v) => { - e.get(*v as usize).cloned().expect("index out of bounds") - } - _ => BooleanExpression::Select(e, box index), + UExpressionInner::Value(v) => e + .get(*v as usize) + .cloned() + .ok_or(Error::OutOfBounds(*v, e.len() as u128)), + _ => Ok(BooleanExpression::Select(e, box index)), } } BooleanExpression::FieldLt(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - BooleanExpression::Value(n1 < n2) + Ok(BooleanExpression::Value(n1 < n2)) } - (e1, e2) => BooleanExpression::FieldLt(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::FieldLt(box e1, box e2)), } } BooleanExpression::FieldLe(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - BooleanExpression::Value(n1 <= n2) + Ok(BooleanExpression::Value(n1 <= n2)) } - (e1, e2) => BooleanExpression::FieldLe(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)), } } BooleanExpression::FieldGe(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - BooleanExpression::Value(n1 >= n2) + Ok(BooleanExpression::Value(n1 >= n2)) } - (e1, e2) => BooleanExpression::FieldGe(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)), } } BooleanExpression::FieldGt(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - BooleanExpression::Value(n1 > n2) + Ok(BooleanExpression::Value(n1 > n2)) } - (e1, e2) => BooleanExpression::FieldGt(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)), } } BooleanExpression::FieldEq(box e1, box e2) => { match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(v1), FieldElementExpression::Number(v2)) => { - BooleanExpression::Value(v1.eq(&v2)) + Ok(BooleanExpression::Value(v1.eq(&v2))) } (e1, e2) => { if e1.eq(&e2) { - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) } else { - BooleanExpression::FieldEq(box e1, box e2) + Ok(BooleanExpression::FieldEq(box e1, box e2)) } } } } BooleanExpression::UintLt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - BooleanExpression::Value(v1 < v2) + Ok(BooleanExpression::Value(v1 < v2)) } - _ => BooleanExpression::UintLt(box e1, box e2), + _ => Ok(BooleanExpression::UintLt(box e1, box e2)), } } BooleanExpression::UintLe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - BooleanExpression::Value(v1 <= v2) + Ok(BooleanExpression::Value(v1 <= v2)) } - _ => BooleanExpression::UintLe(box e1, box e2), + _ => Ok(BooleanExpression::UintLe(box e1, box e2)), } } BooleanExpression::UintGe(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - BooleanExpression::Value(v1 >= v2) + Ok(BooleanExpression::Value(v1 >= v2)) } - _ => BooleanExpression::UintGe(box e1, box e2), + _ => Ok(BooleanExpression::UintGe(box e1, box e2)), } } BooleanExpression::UintGt(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - BooleanExpression::Value(v1 > v2) + Ok(BooleanExpression::Value(v1 > v2)) } - _ => BooleanExpression::UintGt(box e1, box e2), + _ => Ok(BooleanExpression::UintGt(box e1, box e2)), } } BooleanExpression::UintEq(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - BooleanExpression::Value(v1 == v2) + Ok(BooleanExpression::Value(v1 == v2)) } _ => { if e1.eq(&e2) { - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) } else { - BooleanExpression::UintEq(box e1, box e2) + Ok(BooleanExpression::UintEq(box e1, box e2)) } } } } BooleanExpression::BoolEq(box e1, box e2) => { match ( - self.fold_boolean_expression(e1), - self.fold_boolean_expression(e2), + self.fold_boolean_expression(e1)?, + self.fold_boolean_expression(e2)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - BooleanExpression::Value(v1 == v2) + Ok(BooleanExpression::Value(v1 == v2)) } (e1, e2) => { if e1.eq(&e2) { - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) } else { - BooleanExpression::BoolEq(box e1, box e2) + Ok(BooleanExpression::BoolEq(box e1, box e2)) } } } } BooleanExpression::Or(box e1, box e2) => { match ( - self.fold_boolean_expression(e1), - self.fold_boolean_expression(e2), + self.fold_boolean_expression(e1)?, + self.fold_boolean_expression(e2)?, ) { (BooleanExpression::Value(v1), BooleanExpression::Value(v2)) => { - BooleanExpression::Value(v1 || v2) + Ok(BooleanExpression::Value(v1 || v2)) } (_, BooleanExpression::Value(true)) | (BooleanExpression::Value(true), _) => { - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) } (e, BooleanExpression::Value(false)) | (BooleanExpression::Value(false), e) => { - e + Ok(e) } - (e1, e2) => BooleanExpression::Or(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::Or(box e1, box e2)), } } BooleanExpression::And(box e1, box e2) => { match ( - self.fold_boolean_expression(e1), - self.fold_boolean_expression(e2), + self.fold_boolean_expression(e1)?, + self.fold_boolean_expression(e2)?, ) { - (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => e, + (BooleanExpression::Value(true), e) | (e, BooleanExpression::Value(true)) => { + Ok(e) + } (BooleanExpression::Value(false), _) | (_, BooleanExpression::Value(false)) => { - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) } - (e1, e2) => BooleanExpression::And(box e1, box e2), + (e1, e2) => Ok(BooleanExpression::And(box e1, box e2)), } } - BooleanExpression::Not(box e) => match self.fold_boolean_expression(e) { - BooleanExpression::Value(v) => BooleanExpression::Value(!v), - e => BooleanExpression::Not(box e), + BooleanExpression::Not(box e) => match self.fold_boolean_expression(e)? { + BooleanExpression::Value(v) => Ok(BooleanExpression::Value(!v)), + e => Ok(BooleanExpression::Not(box e)), }, BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - let condition = self.fold_boolean_expression(condition); - let consequence = self.fold_boolean_expression(consequence); - let alternative = self.fold_boolean_expression(alternative); + let condition = self.fold_boolean_expression(condition)?; + let consequence = self.fold_boolean_expression(consequence)?; + let alternative = self.fold_boolean_expression(alternative)?; match (condition, consequence, alternative) { - (_, consequence, alternative) if consequence == alternative => consequence, - (BooleanExpression::Value(true), consequence, _) => consequence, - (BooleanExpression::Value(false), _, alternative) => alternative, - (condition, consequence, alternative) => { - BooleanExpression::IfElse(box condition, box consequence, box alternative) - } + (_, consequence, alternative) if consequence == alternative => Ok(consequence), + (BooleanExpression::Value(true), consequence, _) => Ok(consequence), + (BooleanExpression::Value(false), _, alternative) => Ok(alternative), + (condition, consequence, alternative) => Ok(BooleanExpression::IfElse( + box condition, + box consequence, + box alternative, + )), } } } @@ -401,200 +435,208 @@ impl<'ast, T: Field> Folder<'ast, T> for ZirPropagator<'ast, T> { &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { + ) -> Result, Self::Error> { match e { - UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Value(v) => Ok(UExpressionInner::Value(v)), UExpressionInner::Identifier(id) => match self.constants.get(&id) { - Some(ZirExpression::Uint(e)) => e.as_inner().clone(), - _ => UExpressionInner::Identifier(id), + Some(ZirExpression::Uint(e)) => Ok(e.as_inner().clone()), + _ => Ok(UExpressionInner::Identifier(id)), }, UExpressionInner::Select(e, box index) => { - let index = self.fold_uint_expression(index); + let index = self.fold_uint_expression(index)?; let e: Vec> = e .into_iter() .map(|e| self.fold_uint_expression(e)) - .collect(); + .collect::>()?; match index.into_inner() { UExpressionInner::Value(v) => e .get(v as usize) .cloned() - .expect("index out of bounds") - .into_inner(), - i => UExpressionInner::Select(e, box i.annotate(bitwidth)), + .ok_or(Error::OutOfBounds(v, e.len() as u128)) + .map(|e| e.into_inner()), + i => Ok(UExpressionInner::Select(e, box i.annotate(bitwidth))), } } UExpressionInner::Add(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(0), e) | (e, UExpressionInner::Value(0)) => e, - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value((n1 + n2) % 2_u128.pow(bitwidth.to_usize() as u32)) - } - (e1, e2) => { - UExpressionInner::Add(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) - } + (UExpressionInner::Value(0), e) | (e, UExpressionInner::Value(0)) => Ok(e), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( + UExpressionInner::Value((n1 + n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + ), + (e1, e2) => Ok(UExpressionInner::Add( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Sub(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) => e, + (e, UExpressionInner::Value(0)) => Ok(e), (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value( + Ok(UExpressionInner::Value( n1.wrapping_sub(n2) % 2_u128.pow(bitwidth.to_usize() as u32), - ) - } - (e1, e2) => { - UExpressionInner::Sub(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + )) } + (e1, e2) => Ok(UExpressionInner::Sub( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Mult(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - UExpressionInner::Value(0) - } - (e, UExpressionInner::Value(1)) | (UExpressionInner::Value(1), e) => e, - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value((n1 * n2) % 2_u128.pow(bitwidth.to_usize() as u32)) - } - (e1, e2) => { - UExpressionInner::Mult(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + Ok(UExpressionInner::Value(0)) } + (e, UExpressionInner::Value(1)) | (UExpressionInner::Value(1), e) => Ok(e), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( + UExpressionInner::Value((n1 * n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + ), + (e1, e2) => Ok(UExpressionInner::Mult( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Div(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(n)) if n == 1 => e, - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)) - } - (e1, e2) => { - UExpressionInner::Div(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) - } + (e, UExpressionInner::Value(n)) if n == 1 => Ok(e), + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( + UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + ), + (e1, e2) => Ok(UExpressionInner::Div( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Rem(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { - (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value((n1 % n2) % 2_u128.pow(bitwidth.to_usize() as u32)) - } - (e1, e2) => { - UExpressionInner::Rem(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) - } + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( + UExpressionInner::Value((n1 % n2) % 2_u128.pow(bitwidth.to_usize() as u32)), + ), + (e1, e2) => Ok(UExpressionInner::Rem( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Xor(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 ^ n2) - } - (e1, e2) if e1.eq(&e2) => UExpressionInner::Value(0), - (e1, e2) => { - UExpressionInner::Xor(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + Ok(UExpressionInner::Value(n1 ^ n2)) } + (e1, e2) if e1.eq(&e2) => Ok(UExpressionInner::Value(0)), + (e1, e2) => Ok(UExpressionInner::Xor( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::And(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { (e, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), e) if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { - e + Ok(e) } (_, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), _) => { - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 & n2) - } - (e1, e2) => { - UExpressionInner::And(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + Ok(UExpressionInner::Value(n1 & n2)) } + (e1, e2) => Ok(UExpressionInner::And( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::Or(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { - (e, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), e) => e, + (e, UExpressionInner::Value(0)) | (UExpressionInner::Value(0), e) => Ok(e), (_, UExpressionInner::Value(n)) | (UExpressionInner::Value(n), _) if n == 2_u128.pow(bitwidth.to_usize() as u32) - 1 => { - UExpressionInner::Value(n) + Ok(UExpressionInner::Value(n)) } (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { - UExpressionInner::Value(n1 | n2) - } - (e1, e2) => { - UExpressionInner::Or(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + Ok(UExpressionInner::Value(n1 | n2)) } + (e1, e2) => Ok(UExpressionInner::Or( + box e1.annotate(bitwidth), + box e2.annotate(bitwidth), + )), } } UExpressionInner::LeftShift(box e, by) => { - let e = self.fold_uint_expression(e); + let e = self.fold_uint_expression(e)?; match (e.into_inner(), by) { - (e, 0) => e, - (_, by) if by >= bitwidth as u32 => UExpressionInner::Value(0), - (UExpressionInner::Value(n), by) => { - UExpressionInner::Value((n << by) & (2_u128.pow(bitwidth as u32) - 1)) - } - (e, by) => UExpressionInner::LeftShift(box e.annotate(bitwidth), by), + (e, 0) => Ok(e), + (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), + (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value( + (n << by) & (2_u128.pow(bitwidth as u32) - 1), + )), + (e, by) => Ok(UExpressionInner::LeftShift(box e.annotate(bitwidth), by)), } } UExpressionInner::RightShift(box e, by) => { - let e = self.fold_uint_expression(e); + let e = self.fold_uint_expression(e)?; match (e.into_inner(), by) { - (e, 0) => e, - (_, by) if by >= bitwidth as u32 => UExpressionInner::Value(0), - (UExpressionInner::Value(n), by) => UExpressionInner::Value(n >> by), - (e, by) => UExpressionInner::RightShift(box e.annotate(bitwidth), by), + (e, 0) => Ok(e), + (_, by) if by >= bitwidth as u32 => Ok(UExpressionInner::Value(0)), + (UExpressionInner::Value(n), by) => Ok(UExpressionInner::Value(n >> by)), + (e, by) => Ok(UExpressionInner::RightShift(box e.annotate(bitwidth), by)), } } UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e); + let e = self.fold_uint_expression(e)?; match e.into_inner() { - UExpressionInner::Value(n) => { - UExpressionInner::Value(!n & (2_u128.pow(bitwidth as u32) - 1)) - } - e => UExpressionInner::Not(box e.annotate(bitwidth)), + UExpressionInner::Value(n) => Ok(UExpressionInner::Value( + !n & (2_u128.pow(bitwidth as u32) - 1), + )), + e => Ok(UExpressionInner::Not(box e.annotate(bitwidth))), } } UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let condition = self.fold_boolean_expression(condition); - let consequence = self.fold_uint_expression(consequence).into_inner(); - let alternative = self.fold_uint_expression(alternative).into_inner(); + let condition = self.fold_boolean_expression(condition)?; + let consequence = self.fold_uint_expression(consequence)?.into_inner(); + let alternative = self.fold_uint_expression(alternative)?.into_inner(); match (condition, consequence, alternative) { - (_, consequence, alternative) if consequence == alternative => consequence, - (BooleanExpression::Value(true), consequence, _) => consequence, - (BooleanExpression::Value(false), _, alternative) => alternative, - (condition, consequence, alternative) => UExpressionInner::IfElse( + (_, consequence, alternative) if consequence == alternative => Ok(consequence), + (BooleanExpression::Value(true), consequence, _) => Ok(consequence), + (BooleanExpression::Value(false), _, alternative) => Ok(alternative), + (condition, consequence, alternative) => Ok(UExpressionInner::IfElse( box condition, box consequence.annotate(bitwidth), box alternative.annotate(bitwidth), - ), + )), } } } @@ -623,7 +665,11 @@ mod tests { let mut propagator = ZirPropagator::default(); let statements: Vec> = statements .into_iter() - .flat_map(|s| propagator.fold_statement(s)) + .map(|s| propagator.fold_statement(s)) + .collect::, _>>() + .unwrap() + .into_iter() + .flatten() .collect(); assert_eq!( @@ -651,7 +697,18 @@ mod tests { ], box UExpressionInner::Value(1).annotate(UBitwidth::B32), )), - FieldElementExpression::Number(Bn128Field::from(2)) + Ok(FieldElementExpression::Number(Bn128Field::from(2))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Select( + vec![ + FieldElementExpression::Number(Bn128Field::from(1)), + FieldElementExpression::Number(Bn128Field::from(2)), + ], + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + Err(Error::OutOfBounds(3, 2)) ); } @@ -664,7 +721,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - FieldElementExpression::Number(Bn128Field::from(5)) + Ok(FieldElementExpression::Number(Bn128Field::from(5))) ); // a + 0 = a @@ -673,7 +730,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box FieldElementExpression::Number(Bn128Field::from(0)), )), - FieldElementExpression::Identifier("a".into()) + Ok(FieldElementExpression::Identifier("a".into())) ); } @@ -686,7 +743,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(1)) + Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); // a - 0 = a @@ -695,7 +752,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box FieldElementExpression::Number(Bn128Field::from(0)), )), - FieldElementExpression::Identifier("a".into()) + Ok(FieldElementExpression::Identifier("a".into())) ); } @@ -708,7 +765,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(6)) + Ok(FieldElementExpression::Number(Bn128Field::from(6))) ); // a * 0 = 0 @@ -717,7 +774,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box FieldElementExpression::Number(Bn128Field::from(0)), )), - FieldElementExpression::Number(Bn128Field::from(0)) + Ok(FieldElementExpression::Number(Bn128Field::from(0))) ); // a * 1 = a @@ -726,7 +783,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box FieldElementExpression::Number(Bn128Field::from(1)), )), - FieldElementExpression::Identifier("a".into()) + Ok(FieldElementExpression::Identifier("a".into())) ); } @@ -739,7 +796,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(6)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(3)) + Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); assert_eq!( @@ -747,7 +804,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box FieldElementExpression::Number(Bn128Field::from(1)), )), - FieldElementExpression::Identifier("a".into()) + Ok(FieldElementExpression::Identifier("a".into())) ); } @@ -760,7 +817,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box UExpressionInner::Value(2).annotate(UBitwidth::B32), )), - FieldElementExpression::Number(Bn128Field::from(9)) + Ok(FieldElementExpression::Number(Bn128Field::from(9))) ); // a ** 0 = 1 @@ -769,7 +826,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box UExpressionInner::Value(0).annotate(UBitwidth::B32), )), - FieldElementExpression::Number(Bn128Field::from(1)) + Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); // a ** 1 = a @@ -778,7 +835,7 @@ mod tests { box FieldElementExpression::Identifier("a".into()), box UExpressionInner::Value(1).annotate(UBitwidth::B32), )), - FieldElementExpression::Identifier("a".into()) + Ok(FieldElementExpression::Identifier("a".into())) ); } @@ -792,7 +849,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(1)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(1)) + Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); assert_eq!( @@ -801,7 +858,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(1)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(2)) + Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); assert_eq!( @@ -810,7 +867,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - FieldElementExpression::Number(Bn128Field::from(2)) + Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); } } @@ -831,7 +888,18 @@ mod tests { ], box UExpressionInner::Value(1).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) + ); + + assert_eq!( + propagator.fold_boolean_expression(BooleanExpression::Select( + vec![ + BooleanExpression::Value(false), + BooleanExpression::Value(true), + ], + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + )), + Err(Error::OutOfBounds(3, 2)) ); } @@ -844,7 +912,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -852,7 +920,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -865,7 +933,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -873,7 +941,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -886,7 +954,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -894,7 +962,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -907,7 +975,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -915,7 +983,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(3)), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -928,7 +996,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -936,7 +1004,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(3)), box FieldElementExpression::Number(Bn128Field::from(2)), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -949,7 +1017,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -957,7 +1025,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -970,7 +1038,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -978,7 +1046,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -991,7 +1059,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(2).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -999,7 +1067,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -1012,7 +1080,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(2).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1020,7 +1088,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -1033,7 +1101,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), box UExpressionInner::Value(2).annotate(UBitwidth::B32), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1041,7 +1109,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), box UExpressionInner::Value(3).annotate(UBitwidth::B32), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -1054,7 +1122,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(true), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1062,7 +1130,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(false), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -1075,7 +1143,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(true), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1083,7 +1151,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(false), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); assert_eq!( @@ -1091,7 +1159,7 @@ mod tests { box BooleanExpression::Identifier("a".into()), box BooleanExpression::Value(true), )), - BooleanExpression::Identifier("a".into()) + Ok(BooleanExpression::Identifier("a".into())) ); assert_eq!( @@ -1099,7 +1167,7 @@ mod tests { box BooleanExpression::Identifier("a".into()), box BooleanExpression::Value(false), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); } @@ -1112,7 +1180,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(true), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1120,7 +1188,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(false), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -1132,14 +1200,14 @@ mod tests { propagator.fold_boolean_expression(BooleanExpression::Not( box BooleanExpression::Value(true), )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); assert_eq!( propagator.fold_boolean_expression(BooleanExpression::Not( box BooleanExpression::Value(false), )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } @@ -1153,7 +1221,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(false) )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); assert_eq!( @@ -1162,7 +1230,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(false) )), - BooleanExpression::Value(false) + Ok(BooleanExpression::Value(false)) ); assert_eq!( @@ -1171,7 +1239,7 @@ mod tests { box BooleanExpression::Value(true), box BooleanExpression::Value(true) )), - BooleanExpression::Value(true) + Ok(BooleanExpression::Value(true)) ); } } @@ -1195,7 +1263,21 @@ mod tests { box UExpressionInner::Value(1).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) + ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Select( + vec![ + UExpressionInner::Value(1).annotate(UBitwidth::B32), + UExpressionInner::Value(2).annotate(UBitwidth::B32), + ], + box UExpressionInner::Value(3).annotate(UBitwidth::B32), + ) + ), + Err(Error::OutOfBounds(3, 2)) ); } @@ -1211,7 +1293,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(5) + Ok(UExpressionInner::Value(5)) ); // a + 0 = a @@ -1223,7 +1305,7 @@ mod tests { box UExpressionInner::Value(0).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); } @@ -1239,7 +1321,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(1) + Ok(UExpressionInner::Value(1)) ); // a - 0 = a @@ -1251,7 +1333,7 @@ mod tests { box UExpressionInner::Value(0).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); } @@ -1267,7 +1349,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(6) + Ok(UExpressionInner::Value(6)) ); // a * 1 = a @@ -1279,7 +1361,7 @@ mod tests { box UExpressionInner::Value(1).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); // a * 0 = 0 @@ -1291,7 +1373,7 @@ mod tests { box UExpressionInner::Value(0).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) ); } @@ -1307,7 +1389,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(3) + Ok(UExpressionInner::Value(3)) ); assert_eq!( @@ -1318,7 +1400,7 @@ mod tests { box UExpressionInner::Value(1).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); } @@ -1334,7 +1416,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) ); assert_eq!( @@ -1345,7 +1427,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(1) + Ok(UExpressionInner::Value(1)) ); } @@ -1361,7 +1443,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(1) + Ok(UExpressionInner::Value(1)) ); assert_eq!( @@ -1372,7 +1454,7 @@ mod tests { box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) ); } @@ -1388,7 +1470,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) ); assert_eq!( @@ -1399,7 +1481,7 @@ mod tests { box UExpressionInner::Value(0).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) ); assert_eq!( @@ -1410,7 +1492,7 @@ mod tests { box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); } @@ -1426,7 +1508,7 @@ mod tests { box UExpressionInner::Value(3).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(3) + Ok(UExpressionInner::Value(3)) ); assert_eq!( @@ -1437,7 +1519,7 @@ mod tests { box UExpressionInner::Value(0).annotate(UBitwidth::B32), ) ), - UExpressionInner::Identifier("a".into()) + Ok(UExpressionInner::Identifier("a".into())) ); assert_eq!( @@ -1448,7 +1530,7 @@ mod tests { box UExpressionInner::Value(u32::MAX as u128).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(u32::MAX as u128) + Ok(UExpressionInner::Value(u32::MAX as u128)) ); } @@ -1464,7 +1546,7 @@ mod tests { 3, ) ), - UExpressionInner::Value(16) + Ok(UExpressionInner::Value(16)) ); assert_eq!( @@ -1475,7 +1557,7 @@ mod tests { 0, ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) ); assert_eq!( @@ -1486,7 +1568,7 @@ mod tests { 32, ) ), - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) ); } @@ -1502,7 +1584,7 @@ mod tests { 2, ) ), - UExpressionInner::Value(1) + Ok(UExpressionInner::Value(1)) ); assert_eq!( @@ -1513,7 +1595,7 @@ mod tests { 0, ) ), - UExpressionInner::Value(4) + Ok(UExpressionInner::Value(4)) ); assert_eq!( @@ -1524,7 +1606,7 @@ mod tests { 32, ) ), - UExpressionInner::Value(0) + Ok(UExpressionInner::Value(0)) ); } @@ -1537,7 +1619,7 @@ mod tests { UBitwidth::B32, UExpressionInner::Not(box UExpressionInner::Value(2).annotate(UBitwidth::B32),) ), - UExpressionInner::Value(4294967293) + Ok(UExpressionInner::Value(4294967293)) ); } @@ -1554,7 +1636,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(1) + Ok(UExpressionInner::Value(1)) ); assert_eq!( @@ -1566,7 +1648,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) ); assert_eq!( @@ -1578,7 +1660,7 @@ mod tests { box UExpressionInner::Value(2).annotate(UBitwidth::B32), ) ), - UExpressionInner::Value(2) + Ok(UExpressionInner::Value(2)) ); } } diff --git a/zokrates_core/src/zir/mod.rs b/zokrates_core/src/zir/mod.rs index a3bbe705d..db5ab4a69 100644 --- a/zokrates_core/src/zir/mod.rs +++ b/zokrates_core/src/zir/mod.rs @@ -2,6 +2,7 @@ pub mod folder; mod from_typed; mod identifier; mod parameter; +pub mod result_folder; pub mod types; mod uint; mod variable; diff --git a/zokrates_core/src/zir/result_folder.rs b/zokrates_core/src/zir/result_folder.rs index e69de29bb..c01592e72 100644 --- a/zokrates_core/src/zir/result_folder.rs +++ b/zokrates_core/src/zir/result_folder.rs @@ -0,0 +1,418 @@ +// Generic walk through a typed AST. Not mutating in place + +use crate::zir::types::UBitwidth; +use crate::zir::*; +use zokrates_field::Field; + +pub trait ResultFolder<'ast, T: Field>: Sized { + type Error; + + fn fold_program(&mut self, p: ZirProgram<'ast, T>) -> Result, Self::Error> { + fold_program(self, p) + } + + fn fold_function( + &mut self, + f: ZirFunction<'ast, T>, + ) -> Result, Self::Error> { + fold_function(self, f) + } + + fn fold_parameter(&mut self, p: Parameter<'ast>) -> Result, Self::Error> { + Ok(Parameter { + id: self.fold_variable(p.id)?, + ..p + }) + } + + fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { + Ok(n) + } + + fn fold_variable(&mut self, v: Variable<'ast>) -> Result, Self::Error> { + Ok(Variable { + id: self.fold_name(v.id)?, + ..v + }) + } + + fn fold_assignee(&mut self, a: ZirAssignee<'ast>) -> Result, Self::Error> { + self.fold_variable(a) + } + + fn fold_statement( + &mut self, + s: ZirStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement(self, s) + } + + fn fold_expression( + &mut self, + e: ZirExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + ZirExpression::FieldElement(e) => Ok(self.fold_field_expression(e)?.into()), + ZirExpression::Boolean(e) => Ok(self.fold_boolean_expression(e)?.into()), + ZirExpression::Uint(e) => Ok(self.fold_uint_expression(e)?.into()), + } + } + + fn fold_expression_list( + &mut self, + es: ZirExpressionList<'ast, T>, + ) -> Result, Self::Error> { + match es { + ZirExpressionList::EmbedCall(embed, generics, arguments) => { + Ok(ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::>()?, + )) + } + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression(self, e) + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression(self, e) + } + + fn fold_uint_expression( + &mut self, + e: UExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression(self, e) + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_inner(self, bitwidth, e) + } +} + +pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: ZirStatement<'ast, T>, +) -> Result>, F::Error> { + let res = match s { + ZirStatement::Return(expressions) => ZirStatement::Return( + expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + ZirStatement::Definition(a, e) => { + ZirStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?) + } + ZirStatement::IfElse(condition, consequence, alternative) => ZirStatement::IfElse( + f.fold_boolean_expression(condition)?, + consequence + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + alternative + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ), + ZirStatement::Assertion(e) => ZirStatement::Assertion(f.fold_boolean_expression(e)?), + ZirStatement::MultipleDefinition(variables, elist) => ZirStatement::MultipleDefinition( + variables + .into_iter() + .map(|v| f.fold_assignee(v)) + .collect::>()?, + f.fold_expression_list(elist)?, + ), + }; + Ok(vec![res]) +} + +pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + Ok(match e { + FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Identifier(id) => { + FieldElementExpression::Identifier(f.fold_name(id)?) + } + FieldElementExpression::Select(a, box i) => FieldElementExpression::Select( + a.into_iter() + .map(|a| f.fold_field_expression(a)) + .collect::>()?, + box f.fold_uint_expression(i)?, + ), + FieldElementExpression::Add(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Add(box e1, box e2) + } + FieldElementExpression::Sub(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Sub(box e1, box e2) + } + FieldElementExpression::Mult(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Mult(box e1, box e2) + } + FieldElementExpression::Div(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Div(box e1, box e2) + } + FieldElementExpression::Pow(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + FieldElementExpression::Pow(box e1, box e2) + } + FieldElementExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_field_expression(cons)?; + let alt = f.fold_field_expression(alt)?; + FieldElementExpression::IfElse(box cond, box cons, box alt) + } + }) +} + +pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + Ok(match e { + BooleanExpression::Value(v) => BooleanExpression::Value(v), + BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?), + BooleanExpression::Select(a, box i) => BooleanExpression::Select( + a.into_iter() + .map(|a| f.fold_boolean_expression(a)) + .collect::>()?, + box f.fold_uint_expression(i)?, + ), + BooleanExpression::FieldEq(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldEq(box e1, box e2) + } + BooleanExpression::BoolEq(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::BoolEq(box e1, box e2) + } + BooleanExpression::UintEq(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintEq(box e1, box e2) + } + BooleanExpression::FieldLt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLt(box e1, box e2) + } + BooleanExpression::FieldLe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLe(box e1, box e2) + } + BooleanExpression::FieldGt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGt(box e1, box e2) + } + BooleanExpression::FieldGe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGe(box e1, box e2) + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLt(box e1, box e2) + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLe(box e1, box e2) + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGt(box e1, box e2) + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGe(box e1, box e2) + } + BooleanExpression::Or(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::Or(box e1, box e2) + } + BooleanExpression::And(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::And(box e1, box e2) + } + BooleanExpression::Not(box e) => { + let e = f.fold_boolean_expression(e)?; + BooleanExpression::Not(box e) + } + BooleanExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_boolean_expression(cons)?; + let alt = f.fold_boolean_expression(alt)?; + BooleanExpression::IfElse(box cond, box cons, box alt) + } + }) +} + +pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: UExpression<'ast, T>, +) -> Result, F::Error> { + Ok(UExpression { + inner: f.fold_uint_expression_inner(e.bitwidth, e.inner)?, + ..e + }) +} + +pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + Ok(match e { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?), + UExpressionInner::Select(a, box i) => UExpressionInner::Select( + a.into_iter() + .map(|a| f.fold_uint_expression(a)) + .collect::>()?, + box f.fold_uint_expression(i)?, + ), + UExpressionInner::Add(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Add(box left, box right) + } + UExpressionInner::Sub(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Sub(box left, box right) + } + UExpressionInner::Mult(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Mult(box left, box right) + } + UExpressionInner::Div(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Div(box left, box right) + } + UExpressionInner::Rem(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Rem(box left, box right) + } + UExpressionInner::Xor(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Xor(box left, box right) + } + UExpressionInner::And(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::And(box left, box right) + } + UExpressionInner::Or(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Or(box left, box right) + } + UExpressionInner::LeftShift(box e, by) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::LeftShift(box e, by) + } + UExpressionInner::RightShift(box e, by) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::RightShift(box e, by) + } + UExpressionInner::Not(box e) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::Not(box e) + } + UExpressionInner::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_uint_expression(cons)?; + let alt = f.fold_uint_expression(alt)?; + + UExpressionInner::IfElse(box cond, box cons, box alt) + } + }) +} + +pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + fun: ZirFunction<'ast, T>, +) -> Result, F::Error> { + Ok(ZirFunction { + arguments: fun + .arguments + .into_iter() + .map(|a| f.fold_parameter(a)) + .collect::>()?, + statements: fun + .statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ..fun + }) +} + +pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + p: ZirProgram<'ast, T>, +) -> Result, F::Error> { + Ok(ZirProgram { + main: f.fold_function(p.main)?, + }) +} From 6ea4cfbe84e5081130fc9764e43e7828e8621f72 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 17 Aug 2021 18:44:25 +0200 Subject: [PATCH 07/78] div by zero --- .../examples/compile_errors/div_by_zero.zok | 3 ++ .../src/static_analysis/zir_propagation.rs | 33 +++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/div_by_zero.zok diff --git a/zokrates_cli/examples/compile_errors/div_by_zero.zok b/zokrates_cli/examples/compile_errors/div_by_zero.zok new file mode 100644 index 000000000..733ddaa9b --- /dev/null +++ b/zokrates_cli/examples/compile_errors/div_by_zero.zok @@ -0,0 +1,3 @@ +def main(field input) -> field: + field divisor = if true then 0 else 1 fi + return input / divisor \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 220e6916f..9b09a59ee 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -14,6 +14,7 @@ type Constants<'ast, T> = HashMap, ZirExpression<'ast, T>>; #[derive(Debug, PartialEq)] pub enum Error { OutOfBounds(u128, u128), + DivisionByZero, } impl fmt::Display for Error { @@ -24,6 +25,9 @@ impl fmt::Display for Error { "Out of bounds index ({} >= {}) found in zir during static analysis", index, size ), + Error::DivisionByZero => { + write!(f, "Division by zero detected in zir during static analysis",) + } } } } @@ -105,7 +109,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { UExpressionInner::Value(v) => e .get(v as usize) .cloned() - .ok_or(Error::OutOfBounds(v, e.len() as u128)), + .ok_or_else(|| Error::OutOfBounds(v, e.len() as u128)), i => Ok(FieldElementExpression::Select( e, box i.annotate(UBitwidth::B32), @@ -169,6 +173,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { self.fold_field_expression(e1)?, self.fold_field_expression(e2)?, ) { + (_, FieldElementExpression::Number(n)) if n == T::from(0) => { + Err(Error::DivisionByZero) + } (e, FieldElementExpression::Number(n)) if n == T::from(1) => Ok(e), (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { Ok(FieldElementExpression::Number(n1 / n2)) @@ -234,7 +241,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { UExpressionInner::Value(v) => e .get(*v as usize) .cloned() - .ok_or(Error::OutOfBounds(*v, e.len() as u128)), + .ok_or_else(|| Error::OutOfBounds(*v, e.len() as u128)), _ => Ok(BooleanExpression::Select(e, box index)), } } @@ -453,7 +460,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { UExpressionInner::Value(v) => e .get(v as usize) .cloned() - .ok_or(Error::OutOfBounds(v, e.len() as u128)) + .ok_or_else(|| Error::OutOfBounds(v, e.len() as u128)) .map(|e| e.into_inner()), i => Ok(UExpressionInner::Select(e, box i.annotate(bitwidth))), } @@ -513,6 +520,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { let e2 = self.fold_uint_expression(e2)?; match (e1.into_inner(), e2.into_inner()) { + (_, UExpressionInner::Value(n)) if n == 0 => Err(Error::DivisionByZero), (e, UExpressionInner::Value(n)) if n == 1 => Ok(e), (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => Ok( UExpressionInner::Value((n1 / n2) % 2_u128.pow(bitwidth.to_usize() as u32)), @@ -806,6 +814,14 @@ mod tests { )), Ok(FieldElementExpression::Identifier("a".into())) ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::Div( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(0)), + )), + Err(Error::DivisionByZero) + ); } #[test] @@ -1402,6 +1418,17 @@ mod tests { ), Ok(UExpressionInner::Identifier("a".into())) ); + + assert_eq!( + propagator.fold_uint_expression_inner( + UBitwidth::B32, + UExpressionInner::Div( + box UExpressionInner::Identifier("a".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + ) + ), + Err(Error::DivisionByZero) + ); } #[test] From b5f243965d8bdb1c9d1204f7433d1eee1bfe3248 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 18 Aug 2021 23:51:00 +0200 Subject: [PATCH 08/78] detect mistyped constant during constant inlining --- .../const_complex_type_mismatch.zok | 7 + .../src/static_analysis/constant_inliner.rs | 131 +++++++++++------- zokrates_core/src/static_analysis/mod.rs | 10 +- .../src/static_analysis/propagation.rs | 13 +- zokrates_core/src/typed_absy/result_folder.rs | 8 +- 5 files changed, 113 insertions(+), 56 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/const_complex_type_mismatch.zok diff --git a/zokrates_cli/examples/compile_errors/const_complex_type_mismatch.zok b/zokrates_cli/examples/compile_errors/const_complex_type_mismatch.zok new file mode 100644 index 000000000..97d5c9ccb --- /dev/null +++ b/zokrates_cli/examples/compile_errors/const_complex_type_mismatch.zok @@ -0,0 +1,7 @@ +const u32 ONE = 1 +const u32 TWO = 2 +const field[ONE] ONE_FIELD = [1; TWO] // actually set the value to an array of 2 elements + +def main(field[TWO] TWO_FIELDS): + assert(TWO_FIELDS == ONE_FIELD) // use the value as is + return \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 0e2ace9c7..1ad506e3e 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,15 +1,27 @@ use crate::static_analysis::Propagator; -use crate::typed_absy::folder::*; -use crate::typed_absy::result_folder::ResultFolder; +use crate::typed_absy::result_folder::*; use crate::typed_absy::types::DeclarationConstant; use crate::typed_absy::*; use std::collections::HashMap; use std::convert::TryInto; +use std::fmt; use zokrates_field::Field; type ProgramConstants<'ast, T> = HashMap, TypedExpression<'ast, T>>>; +#[derive(Debug, PartialEq)] +pub enum Error { + Type(String), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Type(s) => write!(f, "{}", s), + } + } +} pub struct ConstantInliner<'ast, T> { modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, @@ -28,7 +40,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { constants, } } - pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + pub fn inline(p: TypedProgram<'ast, T>) -> Result, Error> { let constants = ProgramConstants::new(); let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); inliner.fold_program(p) @@ -66,44 +78,58 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } } -impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { - fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - self.fold_module_id(p.main.clone()); +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { + type Error = Error; - TypedProgram { + fn fold_program( + &mut self, + p: TypedProgram<'ast, T>, + ) -> Result, Self::Error> { + self.fold_module_id(p.main.clone())?; + + Ok(TypedProgram { modules: std::mem::take(&mut self.modules), ..p - } + }) } - fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { + fn fold_module_id( + &mut self, + id: OwnedTypedModuleId, + ) -> Result { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); let m = self.modules.remove(&id).unwrap(); - let m = self.fold_module(m); + let m = self.fold_module(m)?; self.modules.insert(id.clone(), m); self.change_location(current_m_id); } - id + Ok(id) } - fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - TypedModule { + fn fold_module( + &mut self, + m: TypedModule<'ast, T>, + ) -> Result, Self::Error> { + Ok(TypedModule { constants: m .constants .into_iter() .map(|(id, tc)| { + + let id = self.fold_canonical_constant_identifier(id)?; + let constant = match tc { TypedConstantSymbol::There(imported_id) => { // visit the imported symbol. This triggers visiting the corresponding module if needed - let imported_id = self.fold_canonical_constant_identifier(imported_id); + let imported_id = self.fold_canonical_constant_identifier(imported_id)?; // after that, the constant must have been defined defined in the global map. It is already reduced // to a literal, so running propagation isn't required self.get_constant(&imported_id).unwrap() } TypedConstantSymbol::Here(c) => { - let non_propagated_constant = fold_constant(self, c).expression; + let non_propagated_constant = fold_constant(self, c)?.expression; // folding the constant above only reduces it to an expression containing only literals, not to a single literal. // propagating with an empty map of constants reduces it to a single literal Propagator::with_constants(&mut HashMap::default()) @@ -112,65 +138,72 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { } }; - // add to the constant map. The value added is always a single litteral - self.constants - .get_mut(&self.location) - .unwrap() - .insert(id.id.into(), constant.clone()); - - ( - id, - TypedConstantSymbol::Here(TypedConstant { - expression: constant, - }), - ) + if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(*id.ty.clone()).unwrap() == constant.get_type() { + // add to the constant map. The value added is always a single litteral + self.constants + .get_mut(&self.location) + .unwrap() + .insert(id.id.into(), constant.clone()); + + Ok(( + id, + TypedConstantSymbol::Here(TypedConstant { + expression: constant, + }), + )) + } else { + Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.get_type(), id.id, id.ty))) + } }) - .collect(), + .collect::, _>>()?, functions: m .functions .into_iter() - .map(|(key, fun)| { - ( - self.fold_declaration_function_key(key), - self.fold_function_symbol(fun), - ) + .map::, _>(|(key, fun)| { + Ok(( + self.fold_declaration_function_key(key)?, + self.fold_function_symbol(fun)?, + )) }) + .collect::, _>>() + .into_iter() + .flatten() .collect(), - } + }) } fn fold_declaration_constant( &mut self, c: DeclarationConstant<'ast>, - ) -> DeclarationConstant<'ast> { + ) -> Result, Self::Error> { match c { // replace constants by their concrete value in declaration types DeclarationConstant::Constant(id) => { let id = CanonicalConstantIdentifier { - module: self.fold_module_id(id.module), + module: self.fold_module_id(id.module)?, ..id }; - DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() { + Ok(DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() { TypedExpression::Uint(UExpression { inner: UExpressionInner::Value(v), .. }) => v as u32, _ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"), - }) + })) } - c => c, + c => Ok(c), } } fn fold_field_expression( &mut self, e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { + ) -> Result, Self::Error> { match e { FieldElementExpression::Identifier(ref id) => { match self.get_constant_for_identifier(id) { - Some(c) => c.try_into().unwrap(), + Some(c) => Ok(c.try_into().unwrap()), None => fold_field_expression(self, e), } } @@ -181,10 +214,10 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { + ) -> Result, Self::Error> { match e { BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) { - Some(c) => c.try_into().unwrap(), + Some(c) => Ok(c.try_into().unwrap()), None => fold_boolean_expression(self, e), }, e => fold_boolean_expression(self, e), @@ -195,12 +228,12 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { &mut self, size: UBitwidth, e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { + ) -> Result, Self::Error> { match e { UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { Some(c) => { let e: UExpression<'ast, T> = c.try_into().unwrap(); - e.into_inner() + Ok(e.into_inner()) } None => fold_uint_expression_inner(self, size, e), }, @@ -212,13 +245,13 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { &mut self, ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { + ) -> Result, Self::Error> { match e { ArrayExpressionInner::Identifier(ref id) => { match self.get_constant_for_identifier(id) { Some(c) => { let e: ArrayExpression<'ast, T> = c.try_into().unwrap(); - e.into_inner() + Ok(e.into_inner()) } None => fold_array_expression_inner(self, ty, e), } @@ -231,13 +264,13 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { &mut self, ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { + ) -> Result, Self::Error> { match e { StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { Some(c) => { let e: StructExpression<'ast, T> = c.try_into().unwrap(); - e.into_inner() + Ok(e.into_inner()) } None => fold_struct_expression_inner(self, ty, e), }, diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 3b8751f1b..2285115b4 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -40,6 +40,13 @@ pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), NonConstantArgument(self::constant_argument_checker::Error), + ConstantInliner(self::constant_inliner::Error), +} + +impl From for Error { + fn from(e: self::constant_inliner::Error) -> Self { + Error::ConstantInliner(e) + } } impl From for Error { @@ -66,6 +73,7 @@ impl fmt::Display for Error { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), Error::NonConstantArgument(e) => write!(f, "{}", e), + Error::ConstantInliner(e) => write!(f, "{}", e), } } } @@ -74,7 +82,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> { // inline user-defined constants log::debug!("Static analyser: Inline constants"); - let r = ConstantInliner::inline(self); + let r = ConstantInliner::inline(self).map_err(Error::from)?; log::trace!("\n{}", r); // isolate branches diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index c66dc60b1..b90fc99d6 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -82,12 +82,15 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { Ok((variable, constant)) => match index.as_inner() { UExpressionInner::Value(n) => match constant { TypedExpression::Array(a) => match a.as_inner_mut() { - ArrayExpressionInner::Value(value) => match value.0[*n as usize] { - TypedExpressionOrSpread::Expression(ref mut e) => { - Ok((variable, e)) + ArrayExpressionInner::Value(value) => { + match value.0.get_mut(*n as usize) { + Some(TypedExpressionOrSpread::Expression(ref mut e)) => { + Ok((variable, e)) + } + None => Err(variable), + _ => unreachable!(), } - _ => unreachable!(), - }, + } _ => unreachable!("should be an array value"), }, _ => unreachable!("should be an array expression"), diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 162554706..c85e97a7d 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -248,7 +248,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { &mut self, t: DeclarationType<'ast>, ) -> Result, Self::Error> { - Ok(t) + use self::GType::*; + + match t { + Array(array_type) => Ok(Array(self.fold_declaration_array_type(array_type)?)), + Struct(struct_type) => Ok(Struct(self.fold_declaration_struct_type(struct_type)?)), + t => Ok(t), + } } fn fold_declaration_array_type( From bc54c1d4547dce62294caffffe0ede4308cca325 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 19 Aug 2021 15:58:44 +0200 Subject: [PATCH 09/78] use remap-path-prefix in circleci for release builds --- .circleci/config.yml | 2 +- build_release.sh | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 12937a6c4..22d9d1b2a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -168,7 +168,7 @@ jobs: - run: name: Build no_output_timeout: "30m" - command: << parameters.build-with >> build --target << parameters.target >> --release + command: RUSTFLAGS="--remap-path-prefix=$PWD=" << parameters.build-with >> build --target << parameters.target >> --release - tar_artifacts: target: << parameters.target >> publish_artifacts: diff --git a/build_release.sh b/build_release.sh index e0c727c25..d57d78c73 100755 --- a/build_release.sh +++ b/build_release.sh @@ -2,6 +2,7 @@ # Exit if any subcommand fails set -e +export RUSTFLAGS="--remap-path-prefix=$PWD=" if [ -n "$WITH_LIBSNARK" ]; then cargo build --release --package zokrates_cli --features="libsnark" From db9abf4519951c9a1ad82c18f9787b16f43cbe81 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 20 Aug 2021 13:43:34 +0200 Subject: [PATCH 10/78] force reduce x_to_bits embed arguments --- zokrates_core/src/flatten/mod.rs | 5 +- .../src/static_analysis/uint_optimizer.rs | 56 +++++++++++++------ zokrates_core/src/zir/folder.rs | 2 +- zokrates_core/src/zir/result_folder.rs | 2 +- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 901355eb9..dffe54325 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1089,7 +1089,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { bitwidth: UBitwidth, ) -> Vec> { let expression = UExpression::try_from(expression).unwrap(); - let from = expression.metadata.as_ref().unwrap().bitwidth(); + let metadata = expression.metadata.as_ref().unwrap(); + assert!(!metadata.should_reduce.is_unknown()); + + let from = metadata.bitwidth(); let p = self.flatten_uint_expression(statements_flattened, expression); self.get_bits(&p, from as usize, bitwidth, statements_flattened) .into_iter() diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index 965f30180..ab8911a7e 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -492,24 +492,44 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { _ => {} }; - vec![ZirStatement::MultipleDefinition( - lhs, - ZirExpressionList::EmbedCall( - embed, - generics, - arguments - .into_iter() - .map(|e| match e { - ZirExpression::Uint(e) => { - let e = self.fold_uint_expression(e); - let e = force_no_reduce(e); - ZirExpression::Uint(e) - } - e => self.fold_expression(e), - }) - .collect(), - ), - )] + match embed { + FlatEmbed::U8ToBits + | FlatEmbed::U16ToBits + | FlatEmbed::U32ToBits + | FlatEmbed::U64ToBits => { + vec![ZirStatement::MultipleDefinition( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| match e { + ZirExpression::Uint(e) => { + let e = self.fold_uint_expression(e); + let e = force_reduce(e); + ZirExpression::Uint(e) + } + e => self.fold_expression(e), + }) + .collect(), + ), + )] + } + _ => { + vec![ZirStatement::MultipleDefinition( + lhs, + ZirExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|e| self.fold_expression(e)) + .collect(), + ), + )] + } + } } ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right)) => { let left = self.fold_uint_expression(left); diff --git a/zokrates_core/src/zir/folder.rs b/zokrates_core/src/zir/folder.rs index fc4ccc9d1..b8f55e8d1 100644 --- a/zokrates_core/src/zir/folder.rs +++ b/zokrates_core/src/zir/folder.rs @@ -1,4 +1,4 @@ -// Generic walk through a typed AST. Not mutating in place +// Generic walk through ZIR. Not mutating in place use crate::zir::types::UBitwidth; use crate::zir::*; diff --git a/zokrates_core/src/zir/result_folder.rs b/zokrates_core/src/zir/result_folder.rs index c01592e72..791e16baa 100644 --- a/zokrates_core/src/zir/result_folder.rs +++ b/zokrates_core/src/zir/result_folder.rs @@ -1,4 +1,4 @@ -// Generic walk through a typed AST. Not mutating in place +// Generic walk through ZIR. Not mutating in place use crate::zir::types::UBitwidth; use crate::zir::*; From c71b31d1af394d332d45eb2f85e73b122fd61700 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 21 Aug 2021 00:08:12 +0200 Subject: [PATCH 11/78] refactor constants to keep track of them across modules --- zokrates_cli/examples/call_in_const.zok | 6 + zokrates_cli/examples/call_in_const_aux.zok | 9 + .../compile_errors/call_in_constant.zok | 9 + zokrates_core/src/semantics.rs | 18 +- .../src/static_analysis/constant_inliner.rs | 172 ++++++++---------- zokrates_core/src/typed_absy/folder.rs | 2 +- zokrates_core/src/typed_absy/identifier.rs | 9 + zokrates_core/src/typed_absy/mod.rs | 7 +- zokrates_core/src/typed_absy/result_folder.rs | 2 +- zokrates_core/src/typed_absy/types.rs | 25 +-- 10 files changed, 132 insertions(+), 127 deletions(-) create mode 100644 zokrates_cli/examples/call_in_const.zok create mode 100644 zokrates_cli/examples/call_in_const_aux.zok create mode 100644 zokrates_cli/examples/compile_errors/call_in_constant.zok diff --git a/zokrates_cli/examples/call_in_const.zok b/zokrates_cli/examples/call_in_const.zok new file mode 100644 index 000000000..bf3cac375 --- /dev/null +++ b/zokrates_cli/examples/call_in_const.zok @@ -0,0 +1,6 @@ +from "./call_in_const_aux.zok" import A, foo, F +const field[A] Y = [...foo::(F)[..A - 1], 1] + +def main(field[A] X): + assert(X == Y) + return \ No newline at end of file diff --git a/zokrates_cli/examples/call_in_const_aux.zok b/zokrates_cli/examples/call_in_const_aux.zok new file mode 100644 index 000000000..2e3e86b71 --- /dev/null +++ b/zokrates_cli/examples/call_in_const_aux.zok @@ -0,0 +1,9 @@ +const field F = 10 +const u32 A = 10 +const u32 B = A + +def foo(field X) -> field[N]: + return [X; N] + +def main(): + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/call_in_constant.zok b/zokrates_cli/examples/compile_errors/call_in_constant.zok new file mode 100644 index 000000000..5f18f0d98 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/call_in_constant.zok @@ -0,0 +1,9 @@ +// calling a function inside a constant definition is not possible yet + +def yes() -> bool: + return true + +const TRUE = yes() + +def main(): + return \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 937fbf174..28cfb1714 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -355,7 +355,7 @@ impl<'ast, T: Field> Checker<'ast, T> { c: ConstantDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result<(DeclarationType<'ast>, TypedConstant<'ast, T>), ErrorInner> { + ) -> Result, ErrorInner> { let pos = c.pos(); let ty = self.check_declaration_type( @@ -397,7 +397,7 @@ impl<'ast, T: Field> Checker<'ast, T> { ty ), }) - .map(|e| (ty, TypedConstant::new(e))) + .map(|e| TypedConstant::new(e, ty)) } fn check_struct_type_declaration( @@ -501,6 +501,9 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } + // pb: we convert canonical constants into identifiers inside rhs of constant definitions, loosing the module they are from. but then we want to reduce them to literals, which requires knowing which module they come + // if we don't convert, we end up with a new type of core identifier (implemented now) which confuses propagation because they are not equal to the identifier of the same name + fn check_symbol_declaration( &mut self, declaration: SymbolDeclarationNode<'ast>, @@ -554,7 +557,7 @@ impl<'ast, T: Field> Checker<'ast, T> { } Symbol::Here(SymbolDefinition::Constant(c)) => { match self.check_constant_definition(declaration.id, c, module_id, state) { - Ok((d_t, c)) => { + Ok(c) => { match symbol_unifier.insert_constant(declaration.id) { false => errors.push( ErrorInner { @@ -571,7 +574,6 @@ impl<'ast, T: Field> Checker<'ast, T> { CanonicalConstantIdentifier::new( declaration.id, module_id.into(), - d_t.clone(), ), TypedConstantSymbol::Here(c.clone()), )); @@ -583,7 +585,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .constants .entry(module_id.to_path_buf()) .or_default() - .insert(declaration.id, d_t) + .insert(declaration.id, c.ty) .is_none()); } }; @@ -722,8 +724,8 @@ impl<'ast, T: Field> Checker<'ast, T> { }}); } true => { - let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id, ty.clone()); - let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into(), ty.clone()); + let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id); + let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into()); constants.push((id.clone(), TypedConstantSymbol::There(imported_id))); self.insert_into_scope(Variable::with_id_and_type(declaration.id, crate::typed_absy::types::try_from_g_type(ty.clone()).unwrap())); @@ -1301,7 +1303,7 @@ impl<'ast, T: Field> Checker<'ast, T> { match (constants_map.get(name), generics_map.get(&name)) { (Some(ty), None) => { match ty { - DeclarationType::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into(), DeclarationType::Uint(UBitwidth::B32)))), + DeclarationType::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))), _ => Err(ErrorInner { pos: Some(pos), message: format!( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 1ad506e3e..a8423dda1 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -8,7 +8,7 @@ use std::fmt; use zokrates_field::Field; type ProgramConstants<'ast, T> = - HashMap, TypedExpression<'ast, T>>>; + HashMap, TypedConstant<'ast, T>>>; #[derive(Debug, PartialEq)] pub enum Error { @@ -60,21 +60,25 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { fn get_constant( &self, id: &CanonicalConstantIdentifier<'ast>, - ) -> Option> { + ) -> Option> { self.constants .get(&id.module) - .and_then(|constants| constants.get(&id.id.into())) + .and_then(|constants| constants.get(&id.id)) .cloned() } - fn get_constant_for_identifier( - &self, - id: &Identifier<'ast>, - ) -> Option> { - self.constants - .get(&self.location) - .and_then(|constants| constants.get(&id)) - .cloned() + fn get_constant_for_identifier(&self, id: &Identifier<'ast>) -> Option> { + match &id.id { + // canonical constants can be accessed directly in the constant map + CoreIdentifier::Constant(c) => self.get_constant(c), + // source ids are checked against the canonical constant map, setting the module to the current module + CoreIdentifier::Source(id) => self + .constants + .get(&self.location) + .and_then(|constants| constants.get(id)) + .cloned(), + _ => unreachable!(), + } } } @@ -129,16 +133,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { self.get_constant(&imported_id).unwrap() } TypedConstantSymbol::Here(c) => { - let non_propagated_constant = fold_constant(self, c)?.expression; + let non_propagated_constant = fold_constant(self, c)?; // folding the constant above only reduces it to an expression containing only literals, not to a single literal. // propagating with an empty map of constants reduces it to a single literal Propagator::with_constants(&mut HashMap::default()) - .fold_expression(non_propagated_constant) + .fold_constant(non_propagated_constant) .unwrap() } }; - if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(*id.ty.clone()).unwrap() == constant.get_type() { + if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(constant.ty.clone()).unwrap() == constant.expression.get_type() { // add to the constant map. The value added is always a single litteral self.constants .get_mut(&self.location) @@ -147,12 +151,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { Ok(( id, - TypedConstantSymbol::Here(TypedConstant { - expression: constant, - }), + TypedConstantSymbol::Here(constant), )) } else { - Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.get_type(), id.id, id.ty))) + Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.expression.get_type(), id.id, constant.ty))) } }) .collect::, _>>()?, @@ -185,10 +187,13 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { }; Ok(DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() { - TypedExpression::Uint(UExpression { + TypedConstant { + expression: TypedExpression::Uint(UExpression { inner: UExpressionInner::Value(v), .. - }) => v as u32, + }), + ty: DeclarationType::Uint(UBitwidth::B32) + } => v as u32, _ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"), })) } @@ -309,14 +314,11 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new( - const_id, - "main".into(), + CanonicalConstantIdentifier::new(const_id, "main".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), DeclarationType::FieldElement, - ), - TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(1)), - ))), + )), )] .into_iter() .collect(); @@ -377,7 +379,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -400,10 +402,11 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new(const_id, "main".into(), DeclarationType::Boolean), - TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Boolean( - BooleanExpression::Value(true), - ))), + CanonicalConstantIdentifier::new(const_id, "main".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Boolean(BooleanExpression::Value(true)), + DeclarationType::Boolean, + )), )] .into_iter() .collect(); @@ -464,7 +467,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -488,15 +491,12 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new( - const_id, - "main".into(), - DeclarationType::Uint(UBitwidth::B32), - ), + CanonicalConstantIdentifier::new(const_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( UExpressionInner::Value(1u128) .annotate(UBitwidth::B32) .into(), + DeclarationType::Uint(UBitwidth::B32), )), )] .into_iter() @@ -558,7 +558,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -592,24 +592,23 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new( - const_id, - "main".into(), + CanonicalConstantIdentifier::new(const_id, "main".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + ), DeclarationType::Array(DeclarationArrayType::new( DeclarationType::FieldElement, 2u32, )), - ), - TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::Array( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ] - .into(), - ) - .annotate(GType::FieldElement, 2usize), - ))), + )), )] .into_iter() .collect(); @@ -694,7 +693,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -735,23 +734,16 @@ mod tests { .collect(), constants: vec![ ( - CanonicalConstantIdentifier::new( - const_a_id, - "main".into(), - DeclarationType::FieldElement, - ), + CanonicalConstantIdentifier::new(const_a_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(1), )), + DeclarationType::FieldElement, )), ), ( - CanonicalConstantIdentifier::new( - const_b_id, - "main".into(), - DeclarationType::FieldElement, - ), + CanonicalConstantIdentifier::new(const_b_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Add( box FieldElementExpression::Identifier(Identifier::from( @@ -759,6 +751,7 @@ mod tests { )), box FieldElementExpression::Number(Bn128Field::from(1)), )), + DeclarationType::FieldElement, )), ), ] @@ -799,27 +792,21 @@ mod tests { .collect(), constants: vec![ ( - CanonicalConstantIdentifier::new( - const_a_id, - "main".into(), - DeclarationType::FieldElement, - ), + CanonicalConstantIdentifier::new(const_a_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(1), )), + DeclarationType::FieldElement, )), ), ( - CanonicalConstantIdentifier::new( - const_b_id, - "main".into(), - DeclarationType::FieldElement, - ), + CanonicalConstantIdentifier::new(const_b_id, "main".into()), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(2), )), + DeclarationType::FieldElement, )), ), ] @@ -831,7 +818,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -866,14 +853,13 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new( - foo_const_id, - "foo".into(), + CanonicalConstantIdentifier::new(foo_const_id, "foo".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + )), DeclarationType::FieldElement, - ), - TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(42)), - ))), + )), )] .into_iter() .collect(), @@ -899,15 +885,10 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new( - foo_const_id, - "main".into(), - DeclarationType::FieldElement, - ), + CanonicalConstantIdentifier::new(foo_const_id, "main".into()), TypedConstantSymbol::There(CanonicalConstantIdentifier::new( foo_const_id, "foo".into(), - DeclarationType::FieldElement, )), )] .into_iter() @@ -945,14 +926,13 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new( - foo_const_id, - "main".into(), + CanonicalConstantIdentifier::new(foo_const_id, "main".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + )), DeclarationType::FieldElement, - ), - TypedConstantSymbol::Here(TypedConstant::new(TypedExpression::FieldElement( - FieldElementExpression::Number(Bn128Field::from(42)), - ))), + )), )] .into_iter() .collect(), @@ -968,6 +948,6 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } } diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 6bb1cc197..cf2bb958e 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -232,7 +232,6 @@ pub trait Folder<'ast, T: Field>: Sized { CanonicalConstantIdentifier { module: self.fold_module_id(i.module), id: i.id, - ty: box self.fold_declaration_type(*i.ty), } } @@ -1056,6 +1055,7 @@ pub fn fold_constant<'ast, T: Field, F: Folder<'ast, T>>( ) -> TypedConstant<'ast, T> { TypedConstant { expression: f.fold_expression(c.expression), + ty: f.fold_declaration_type(c.ty), } } diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index 00bc7425d..6b2420749 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -1,3 +1,4 @@ +use crate::typed_absy::CanonicalConstantIdentifier; use std::convert::TryInto; use std::fmt; @@ -5,6 +6,7 @@ use std::fmt; pub enum CoreIdentifier<'ast> { Source(&'ast str), Call(usize), + Constant(CanonicalConstantIdentifier<'ast>), } impl<'ast> fmt::Display for CoreIdentifier<'ast> { @@ -12,6 +14,7 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { match self { CoreIdentifier::Source(s) => write!(f, "{}", s), CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i), + CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id), } } } @@ -22,6 +25,12 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { } } +impl<'ast> From> for CoreIdentifier<'ast> { + fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> { + CoreIdentifier::Constant(s) + } +} + /// A identifier for a variable #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub struct Identifier<'ast> { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index b31dfeee1..34ef85904 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -205,7 +205,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { .iter() .map(|(id, symbol)| match symbol { TypedConstantSymbol::Here(ref tc) => { - format!("const {} {} = {}", id.ty, id.id, tc) + format!("const {} {} = {}", tc.ty, id.id, tc.expression) } TypedConstantSymbol::There(ref imported_id) => { format!( @@ -312,11 +312,12 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { #[derive(Clone, PartialEq, Debug)] pub struct TypedConstant<'ast, T> { pub expression: TypedExpression<'ast, T>, + pub ty: DeclarationType<'ast>, } impl<'ast, T> TypedConstant<'ast, T> { - pub fn new(expression: TypedExpression<'ast, T>) -> Self { - TypedConstant { expression } + pub fn new(expression: TypedExpression<'ast, T>, ty: DeclarationType<'ast>) -> Self { + TypedConstant { expression, ty } } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index c85e97a7d..f19062446 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -121,7 +121,6 @@ pub trait ResultFolder<'ast, T: Field>: Sized { Ok(CanonicalConstantIdentifier { module: self.fold_module_id(i.module)?, id: i.id, - ty: box self.fold_declaration_type(*i.ty)?, }) } @@ -1110,6 +1109,7 @@ pub fn fold_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result, F::Error> { Ok(TypedConstant { expression: f.fold_expression(c.expression)?, + ty: f.fold_declaration_type(c.ty)?, }) } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 24d125c4d..9d8268d57 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,4 +1,6 @@ -use crate::typed_absy::{Identifier, OwnedTypedModuleId, UExpression, UExpressionInner}; +use crate::typed_absy::{ + CoreIdentifier, Identifier, OwnedTypedModuleId, UExpression, UExpressionInner, +}; use crate::typed_absy::{TryFrom, TryInto}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; use std::collections::BTreeMap; @@ -107,20 +109,11 @@ pub type ConstantIdentifier<'ast> = &'ast str; pub struct CanonicalConstantIdentifier<'ast> { pub module: OwnedTypedModuleId, pub id: ConstantIdentifier<'ast>, - pub ty: Box>, } impl<'ast> CanonicalConstantIdentifier<'ast> { - pub fn new( - id: ConstantIdentifier<'ast>, - module: OwnedTypedModuleId, - ty: DeclarationType<'ast>, - ) -> Self { - CanonicalConstantIdentifier { - module, - id, - ty: box ty, - } + pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self { + CanonicalConstantIdentifier { module, id } } } @@ -970,12 +963,8 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( impl<'ast, T> From> for UExpression<'ast, T> { fn from(c: CanonicalConstantIdentifier<'ast>) -> Self { - let bitwidth = match *c.ty { - DeclarationType::Uint(bitwidth) => bitwidth, - _ => unreachable!(), - }; - - UExpressionInner::Identifier(Identifier::from(c.id)).annotate(bitwidth) + UExpressionInner::Identifier(Identifier::from(CoreIdentifier::Constant(c))) + .annotate(UBitwidth::B32) } } From f1ecdfdc86805cb02fe7fba226f342f3db9b6f40 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 21 Aug 2021 00:15:39 +0200 Subject: [PATCH 12/78] clippy, changelog --- changelogs/unreleased/975-schaeff | 1 + zokrates_core/src/static_analysis/constant_inliner.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelogs/unreleased/975-schaeff diff --git a/changelogs/unreleased/975-schaeff b/changelogs/unreleased/975-schaeff new file mode 100644 index 000000000..d0073f3cb --- /dev/null +++ b/changelogs/unreleased/975-schaeff @@ -0,0 +1 @@ +Allow calls in constant definitions \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index a8423dda1..ff9c48eb9 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -147,7 +147,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { self.constants .get_mut(&self.location) .unwrap() - .insert(id.id.into(), constant.clone()); + .insert(id.id, constant.clone()); Ok(( id, From 7b5a973407fa1a9b3fc5d8ecfea651390542870c Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 21 Aug 2021 00:16:39 +0200 Subject: [PATCH 13/78] changelog --- changelogs/unreleased/974-schaeff | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/974-schaeff diff --git a/changelogs/unreleased/974-schaeff b/changelogs/unreleased/974-schaeff new file mode 100644 index 000000000..0f2622f47 --- /dev/null +++ b/changelogs/unreleased/974-schaeff @@ -0,0 +1 @@ +Fail on mistyped constants \ No newline at end of file From e77d76a01f3b61830646c75ffe60304d29034518 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 22 Aug 2021 15:51:59 +0200 Subject: [PATCH 14/78] fix tests --- .../src/static_analysis/constant_inliner.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 1ad506e3e..589ace6fd 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -377,7 +377,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -464,7 +464,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -558,7 +558,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -694,7 +694,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -831,7 +831,7 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } #[test] @@ -968,6 +968,6 @@ mod tests { .collect(), }; - assert_eq!(program, expected_program) + assert_eq!(program, Ok(expected_program)) } } From 5a208beea762437864b00e6dc3180c4338fe3aa9 Mon Sep 17 00:00:00 2001 From: dark64 Date: Sun, 22 Aug 2021 23:22:15 +0200 Subject: [PATCH 15/78] gracefully handle unconstrained variables --- zokrates_core/src/compile.rs | 15 ++++--- zokrates_core/src/static_analysis/mod.rs | 27 ++++++----- .../src/static_analysis/unconstrained_vars.rs | 45 ++++++++++++------- 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 172ad7473..8d5289aa0 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -189,15 +189,16 @@ pub fn compile>( ) -> Result, CompileErrors> { let arena = Arena::new(); - let (typed_ast, abi) = check_with_arena(source, location, resolver, config, &arena)?; + let (typed_ast, abi) = + check_with_arena(source, location.to_path_buf(), resolver, config, &arena)?; // flatten input program log::debug!("Flatten"); let program_flattened = Flattener::flatten(typed_ast, config); - // analyse (constant propagation after call resolution) - log::debug!("Analyse flat program"); - let program_flattened = program_flattened.analyse(); + // constant propagation after call resolution + log::debug!("Propagate flat program"); + let program_flattened = program_flattened.propagate(); // convert to ir log::debug!("Convert to IR"); @@ -207,9 +208,11 @@ pub fn compile>( log::debug!("Optimise IR"); let optimized_ir_prog = ir_prog.optimize(); - // analyse (check constraints) + // analyse ir (check constraints) log::debug!("Analyse IR"); - let optimized_ir_prog = optimized_ir_prog.analyse(); + let optimized_ir_prog = optimized_ir_prog + .analyse() + .map_err(|e| CompileErrorInner::from(e).in_file(location.as_path()))?; Ok(CompilationArtifacts { prog: optimized_ir_prog, diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 3b8751f1b..82cee5414 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -33,13 +33,18 @@ use std::fmt; use zokrates_field::Field; pub trait Analyse { - fn analyse(self) -> Self; + type Error; + + fn analyse(self) -> Result + where + Self: Sized; } #[derive(Debug)] pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), NonConstantArgument(self::constant_argument_checker::Error), + UnconstrainedVariable(self::unconstrained_vars::Error), } impl From for Error { @@ -60,12 +65,19 @@ impl From for Error { } } +impl From for Error { + fn from(e: unconstrained_vars::Error) -> Self { + Error::UnconstrainedVariable(e) + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), Error::NonConstantArgument(e) => write!(f, "{}", e), + Error::UnconstrainedVariable(e) => write!(f, "{}", e), } } } @@ -126,16 +138,11 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { } } -impl Analyse for FlatProg { - fn analyse(self) -> Self { - log::debug!("Static analyser: Propagate flat"); - self.propagate() - } -} - impl Analyse for Prog { - fn analyse(self) -> Self { + type Error = Error; + + fn analyse(self) -> Result { log::debug!("Static analyser: Detect unconstrained zir"); - UnconstrainedVariableDetector::detect(self) + UnconstrainedVariableDetector::detect(self).map_err(|e| e.into()) } } diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index 2bcbbd050..5adf4ec49 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -3,6 +3,7 @@ use crate::ir::folder::Folder; use crate::ir::Directive; use crate::ir::Prog; use std::collections::HashSet; +use std::fmt; use zokrates_field::Field; #[derive(Debug)] @@ -10,6 +11,20 @@ pub struct UnconstrainedVariableDetector { pub(self) variables: HashSet, } +#[derive(Debug)] +pub struct Error(usize); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Found unconstrained variables during IR analysis (found {} occurrence{})", + self.0, + if self.0 == 1 { "" } else { "s" } + ) + } +} + impl UnconstrainedVariableDetector { fn new(p: &Prog) -> Self { UnconstrainedVariableDetector { @@ -21,22 +36,16 @@ impl UnconstrainedVariableDetector { .collect(), } } - pub fn detect(p: Prog) -> Prog { + + pub fn detect(p: Prog) -> Result, Error> { let mut instance = Self::new(&p); let p = instance.fold_module(p); - // we should probably handle this case instead of asserting at some point - assert!( - instance.variables.is_empty(), - "Unconstrained variables are not allowed (found {} occurrence{})", - instance.variables.len(), - if instance.variables.len() == 1 { - "" - } else { - "s" - } - ); - p + if instance.variables.is_empty() { + Ok(p) + } else { + Err(Error(instance.variables.len())) + } } } @@ -63,7 +72,6 @@ mod tests { use zokrates_field::Bn128Field; #[test] - #[should_panic] fn should_detect_unconstrained_private_input() { // def main(_0) -> (1): // (1 * ~one) * (42 * ~one) == 1 * ~out_0 @@ -92,7 +100,8 @@ mod tests { main, }; - UnconstrainedVariableDetector::detect(p); + let p = UnconstrainedVariableDetector::detect(p); + assert!(p.is_err()); } #[test] @@ -116,7 +125,8 @@ mod tests { main, }; - UnconstrainedVariableDetector::detect(p); + let p = UnconstrainedVariableDetector::detect(p); + assert!(p.is_ok()); } #[test] @@ -174,6 +184,7 @@ mod tests { main, }; - UnconstrainedVariableDetector::detect(p); + let p = UnconstrainedVariableDetector::detect(p); + assert!(p.is_ok()); } } From db33856ed5c78807123680232ef12c38bfa298da Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 00:56:45 +0200 Subject: [PATCH 16/78] fix test --- zokrates_cli/examples/compile_errors/out_of_bounds.zok | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/zokrates_cli/examples/compile_errors/out_of_bounds.zok b/zokrates_cli/examples/compile_errors/out_of_bounds.zok index f2922a4bc..d0d98e5e2 100644 --- a/zokrates_cli/examples/compile_errors/out_of_bounds.zok +++ b/zokrates_cli/examples/compile_errors/out_of_bounds.zok @@ -1,4 +1,4 @@ -def main() -> field: - field[10] a = [0; 10] - u32 index = if [0f] != [1f] then 1000 else 0 fi - return a[index] \ No newline at end of file +def main(field a, field b) -> field: + field[10] arr = [0; 10] + u32 index = if [a, 1] != [b, 0] then 1000 else 0 fi + return arr[index] \ No newline at end of file From ce158ca5dce198414d308a003fd4ca3c93d3d8de Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 23 Aug 2021 13:10:23 +0200 Subject: [PATCH 17/78] add breaking test when propagation is not powerful enough --- .../constant_reduction_fail.zok | 7 ++++++ .../src/static_analysis/constant_inliner.rs | 22 +++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/constant_reduction_fail.zok diff --git a/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok b/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok new file mode 100644 index 000000000..d1102f745 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok @@ -0,0 +1,7 @@ +def constant() -> u32: + return 3 + +const u32 CONSTANT = constant() + +def main(field[CONSTANT] a): + return \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index ff9c48eb9..6d5220d34 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -7,18 +7,23 @@ use std::convert::TryInto; use std::fmt; use zokrates_field::Field; +// a map of the constants in this program +// the values are constants whose expression does not include any identifier. It does not have to be a single literal, as +// we keep function calls here to be inlined later type ProgramConstants<'ast, T> = HashMap, TypedConstant<'ast, T>>>; #[derive(Debug, PartialEq)] pub enum Error { Type(String), + Propagation(String), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::Type(s) => write!(f, "{}", s), + Error::Propagation(s) => write!(f, "{}", s), } } } @@ -129,13 +134,13 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { // visit the imported symbol. This triggers visiting the corresponding module if needed let imported_id = self.fold_canonical_constant_identifier(imported_id)?; // after that, the constant must have been defined defined in the global map. It is already reduced - // to a literal, so running propagation isn't required + // to the maximum, so running propagation isn't required self.get_constant(&imported_id).unwrap() } TypedConstantSymbol::Here(c) => { let non_propagated_constant = fold_constant(self, c)?; // folding the constant above only reduces it to an expression containing only literals, not to a single literal. - // propagating with an empty map of constants reduces it to a single literal + // propagating with an empty map of constants reduces it to the maximum Propagator::with_constants(&mut HashMap::default()) .fold_constant(non_propagated_constant) .unwrap() @@ -143,7 +148,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { }; if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(constant.ty.clone()).unwrap() == constant.expression.get_type() { - // add to the constant map. The value added is always a single litteral + // add to the constant map self.constants .get_mut(&self.location) .unwrap() @@ -167,9 +172,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { self.fold_function_symbol(fun)?, )) }) - .collect::, _>>() + .collect::, _>>()? .into_iter() - .flatten() .collect(), }) } @@ -186,16 +190,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { ..id }; - Ok(DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() { + match self.get_constant(&id).unwrap() { TypedConstant { expression: TypedExpression::Uint(UExpression { inner: UExpressionInner::Value(v), .. }), ty: DeclarationType::Uint(UBitwidth::B32) - } => v as u32, - _ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"), - })) + } => Ok(DeclarationConstant::Concrete(v as u32)), + c => Err(Error::Propagation(format!("Failed to reduce `{}` to a single u32 literal, try avoiding function calls in the definition of `{}` in {}", c, id.id, id.module.display()))) + } } c => Ok(c), } From 02873b9a78e64bb2c7e0c556ebea3b914b2a8090 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 23 Aug 2021 13:11:42 +0200 Subject: [PATCH 18/78] remove comment --- zokrates_core/src/semantics.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 28cfb1714..37ed7510a 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -501,9 +501,6 @@ impl<'ast, T: Field> Checker<'ast, T> { ))) } - // pb: we convert canonical constants into identifiers inside rhs of constant definitions, loosing the module they are from. but then we want to reduce them to literals, which requires knowing which module they come - // if we don't convert, we end up with a new type of core identifier (implemented now) which confuses propagation because they are not equal to the identifier of the same name - fn check_symbol_declaration( &mut self, declaration: SymbolDeclarationNode<'ast>, From ed97e815c71f61e88cb484fb33fb2563cdf4b8a5 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 15:21:53 +0200 Subject: [PATCH 19/78] refactor flatten_embed_call --- zokrates_core/src/flatten/mod.rs | 117 ++++++++++--------------------- zokrates_core/src/zir/uint.rs | 8 +++ 2 files changed, 46 insertions(+), 79 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index dffe54325..ca916636b 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1,5 +1,3 @@ -#![allow(clippy::needless_collect)] - //! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs @@ -1082,38 +1080,24 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } - fn flatten_u_to_bits( + fn u_to_bits( &mut self, - statements_flattened: &mut FlatStatements, - expression: ZirExpression<'ast, T>, + expression: FlatUExpression, bitwidth: UBitwidth, ) -> Vec> { - let expression = UExpression::try_from(expression).unwrap(); - let metadata = expression.metadata.as_ref().unwrap(); - assert!(!metadata.should_reduce.is_unknown()); + let bits = expression.bits.unwrap(); + assert_eq!(bits.len(), bitwidth.to_usize()); - let from = metadata.bitwidth(); - let p = self.flatten_uint_expression(statements_flattened, expression); - self.get_bits(&p, from as usize, bitwidth, statements_flattened) - .into_iter() - .map(FlatUExpression::with_field) - .collect() + bits.into_iter().map(FlatUExpression::with_field).collect() } - fn flatten_bits_to_u( + fn bits_to_u( &mut self, - statements_flattened: &mut FlatStatements, - bits: Vec>, + bits: Vec>, bitwidth: UBitwidth, ) -> FlatUExpression { + let bits: Vec<_> = bits.into_iter().map(|e| e.get_field_unchecked()).collect(); assert_eq!(bits.len(), bitwidth.to_usize()); - let bits: Vec<_> = bits - .into_iter() - .map(|p| { - self.flatten_expression(statements_flattened, p) - .get_field_unchecked() - }) - .collect(); FlatUExpression::with_bits(bits) } @@ -1131,70 +1115,55 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements, embed: FlatEmbed, generics: Vec, - mut param_expressions: Vec>, + param_expressions: Vec>, ) -> Vec> { + let mut params: Vec<_> = param_expressions + .into_iter() + .map(|p| { + if let ZirExpression::Uint(e) = &p { + assert!(e.metadata.as_ref().unwrap().should_reduce.is_true()); + } + self.flatten_expression(statements_flattened, p) + }) + .collect(); + match embed { - crate::embed::FlatEmbed::U64ToBits => self.flatten_u_to_bits( - statements_flattened, - param_expressions.pop().unwrap(), - 64.into(), - ), - crate::embed::FlatEmbed::U32ToBits => self.flatten_u_to_bits( - statements_flattened, - param_expressions.pop().unwrap(), - 32.into(), - ), - crate::embed::FlatEmbed::U16ToBits => self.flatten_u_to_bits( - statements_flattened, - param_expressions.pop().unwrap(), - 16.into(), - ), - crate::embed::FlatEmbed::U8ToBits => self.flatten_u_to_bits( - statements_flattened, - param_expressions.pop().unwrap(), - 8.into(), - ), - crate::embed::FlatEmbed::U64FromBits => { - vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 64.into())] + FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()), + FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()), + FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()), + FlatEmbed::U64ToBits => self.u_to_bits(params.pop().unwrap(), 64.into()), + FlatEmbed::U8FromBits => { + vec![self.bits_to_u(params, 8.into())] } - crate::embed::FlatEmbed::U32FromBits => { - vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 32.into())] + FlatEmbed::U16FromBits => { + vec![self.bits_to_u(params, 16.into())] } - crate::embed::FlatEmbed::U16FromBits => { - vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 16.into())] + FlatEmbed::U32FromBits => { + vec![self.bits_to_u(params, 32.into())] } - crate::embed::FlatEmbed::U8FromBits => { - vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] + FlatEmbed::U64FromBits => { + vec![self.bits_to_u(params, 64.into())] } - crate::embed::FlatEmbed::BitArrayLe => { + FlatEmbed::BitArrayLe => { // get the length of the bit arrays let len = generics[0]; // split the arguments into the two bit arrays of size `len` let (expressions, constants) = ( - param_expressions[..len as usize].to_vec(), - param_expressions[len as usize..].to_vec(), + params[..len as usize].to_vec(), + params[len as usize..].to_vec(), ); // define variables for the variable bits let variables: Vec<_> = expressions .into_iter() - .map(|e| { - let e = self - .flatten_expression(statements_flattened, e) - .get_field_unchecked(); - self.define(e, statements_flattened) - }) + .map(|e| self.define(e.get_field_unchecked(), statements_flattened)) .collect(); // get constants for the constant bits let constants: Vec<_> = constants .into_iter() - .map(|e| { - self.flatten_expression(statements_flattened, e) - .get_field_unchecked() - }) - .map(|e| match e { + .map(|e| match e.get_field_unchecked() { FlatExpression::Number(n) if n == T::one() => true, FlatExpression::Number(n) if n == T::zero() => false, _ => unreachable!(), @@ -1225,19 +1194,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Handle complex parameters and assign values: // Rename Parameters, assign them to values in call. Resolve complex expressions with definitions - // Clippy doesn't like the fact that we're collecting here, however not doing so leads to a borrow issue - // of `self` in the for-loop just after. This is why the `needless_collect` lint is disabled for this file - // (it does not work for this single line) - let params_flattened = param_expressions - .into_iter() - .map(|param_expr| self.flatten_expression(statements_flattened, param_expr)) - .into_iter() - .map(|x| x.get_field_unchecked()) - .collect::>(); + let params_flattened = params.into_iter().map(|e| e.get_field_unchecked()); - for (concrete_argument, formal_argument) in - params_flattened.into_iter().zip(funct.arguments) - { + for (concrete_argument, formal_argument) in params_flattened.zip(funct.arguments) { let new_var = self.define(concrete_argument, statements_flattened); replacement_map.insert(formal_argument.id, new_var); } diff --git a/zokrates_core/src/zir/uint.rs b/zokrates_core/src/zir/uint.rs index ac944ca2d..3e3b51a26 100644 --- a/zokrates_core/src/zir/uint.rs +++ b/zokrates_core/src/zir/uint.rs @@ -113,6 +113,14 @@ impl ShouldReduce { *self == ShouldReduce::Unknown } + pub fn is_true(&self) -> bool { + *self == ShouldReduce::True + } + + pub fn is_false(&self) -> bool { + *self == ShouldReduce::False + } + // we can always enable a reduction pub fn make_true(self) -> Self { ShouldReduce::True From 389d7d0a040b3b64c53a622046da83e11a4fb8fb Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 15:47:28 +0200 Subject: [PATCH 20/78] fix warnings --- zokrates_core/src/compile.rs | 3 +-- zokrates_core/src/static_analysis/mod.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 8d5289aa0..fc721cfe4 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -189,8 +189,7 @@ pub fn compile>( ) -> Result, CompileErrors> { let arena = Arena::new(); - let (typed_ast, abi) = - check_with_arena(source, location.to_path_buf(), resolver, config, &arena)?; + let (typed_ast, abi) = check_with_arena(source, location.clone(), resolver, config, &arena)?; // flatten input program log::debug!("Flatten"); diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 82cee5414..d2ce4ea2b 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -24,7 +24,6 @@ use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; use crate::compile::CompileConfig; -use crate::flat_absy::FlatProg; use crate::ir::Prog; use crate::static_analysis::constant_inliner::ConstantInliner; use crate::typed_absy::{abi::Abi, TypedProgram}; @@ -143,6 +142,6 @@ impl Analyse for Prog { fn analyse(self) -> Result { log::debug!("Static analyser: Detect unconstrained zir"); - UnconstrainedVariableDetector::detect(self).map_err(|e| e.into()) + UnconstrainedVariableDetector::detect(self).map_err(Error::from) } } From b7a035025d53e7e3f97facc6a14423732c12ac34 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 15:51:56 +0200 Subject: [PATCH 21/78] add test for unconstrained input --- zokrates_cli/examples/compile_errors/unconstrained_input.zok | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 zokrates_cli/examples/compile_errors/unconstrained_input.zok diff --git a/zokrates_cli/examples/compile_errors/unconstrained_input.zok b/zokrates_cli/examples/compile_errors/unconstrained_input.zok new file mode 100644 index 000000000..0bb46addc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/unconstrained_input.zok @@ -0,0 +1,2 @@ +def main(private field a) -> field: + return 1 \ No newline at end of file From 300fd089663a33430923bbaa97d1877c7950b56e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 23 Aug 2021 19:33:59 +0200 Subject: [PATCH 22/78] remove unneeded annotation --- zokrates_core/src/static_analysis/constant_inliner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 6d5220d34..c318c9c91 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -166,7 +166,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { functions: m .functions .into_iter() - .map::, _>(|(key, fun)| { + .map(|(key, fun)| { Ok(( self.fold_declaration_function_key(key)?, self.fold_function_symbol(fun)?, From 3235cf4df9666b5ad436d565ef176d80c9bbf7e9 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 21:23:44 +0200 Subject: [PATCH 23/78] fix div by zero test --- zokrates_cli/examples/compile_errors/div_by_zero.zok | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_cli/examples/compile_errors/div_by_zero.zok b/zokrates_cli/examples/compile_errors/div_by_zero.zok index 733ddaa9b..784a9aec0 100644 --- a/zokrates_cli/examples/compile_errors/div_by_zero.zok +++ b/zokrates_cli/examples/compile_errors/div_by_zero.zok @@ -1,3 +1,3 @@ def main(field input) -> field: - field divisor = if true then 0 else 1 fi + field divisor = if [input, 0] != [input, 1] then 0 else 1 fi return input / divisor \ No newline at end of file From 13ebd1e4d75f162447e28ce49531f0a3e5a359af Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 21:42:09 +0200 Subject: [PATCH 24/78] use visitor instead of folding --- .../src/static_analysis/unconstrained_vars.rs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index 5adf4ec49..5eb0b80eb 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -1,5 +1,5 @@ use crate::flat_absy::FlatVariable; -use crate::ir::folder::Folder; +use crate::ir::visitor::Visitor; use crate::ir::Directive; use crate::ir::Prog; use std::collections::HashSet; @@ -39,7 +39,7 @@ impl UnconstrainedVariableDetector { pub fn detect(p: Prog) -> Result, Error> { let mut instance = Self::new(&p); - let p = instance.fold_module(p); + instance.visit_module(&p); if instance.variables.is_empty() { Ok(p) @@ -49,17 +49,13 @@ impl UnconstrainedVariableDetector { } } -impl Folder for UnconstrainedVariableDetector { - fn fold_argument(&mut self, p: FlatVariable) -> FlatVariable { - p +impl Visitor for UnconstrainedVariableDetector { + fn visit_argument(&mut self, _: &FlatVariable) {} + fn visit_variable(&mut self, v: &FlatVariable) { + self.variables.remove(v); } - fn fold_variable(&mut self, v: FlatVariable) -> FlatVariable { - self.variables.remove(&v); - v - } - fn fold_directive(&mut self, d: Directive) -> Directive { + fn visit_directive(&mut self, d: &Directive) { self.variables.extend(d.outputs.iter()); - d } } From 53e139da7859a7c97b1e98fea2d5d4bf817dfe22 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 23 Aug 2021 21:47:23 +0200 Subject: [PATCH 25/78] add changelog --- changelogs/unreleased/977-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/977-dark64 diff --git a/changelogs/unreleased/977-dark64 b/changelogs/unreleased/977-dark64 new file mode 100644 index 000000000..c29eed7dc --- /dev/null +++ b/changelogs/unreleased/977-dark64 @@ -0,0 +1 @@ +Graceful error handling on unconstrained variable detection \ No newline at end of file From 88197a57d622eb770d5ee229cb66d1dd61523302 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 27 Aug 2021 15:57:14 +0200 Subject: [PATCH 26/78] refactor to keep constants until reduction, then try to inline their definitions --- .../{compile_errors => }/call_in_constant.zok | 2 +- .../constant_reduction_fail.zok | 9 +- .../examples/complex_call_in_constant.zok | 14 + zokrates_core/src/absy/node.rs | 13 +- zokrates_core/src/embed.rs | 147 ++++++++++- zokrates_core/src/parser/tokenize/position.rs | 1 - zokrates_core/src/semantics.rs | 242 ++++++++++-------- .../src/static_analysis/constant_inliner.rs | 229 ++++++++--------- .../static_analysis/flatten_complex_types.rs | 4 +- .../src/static_analysis/propagation.rs | 3 +- .../src/static_analysis/reducer/inline.rs | 15 +- .../src/static_analysis/reducer/mod.rs | 177 ++++++++++++- zokrates_core/src/typed_absy/folder.rs | 51 ++-- zokrates_core/src/typed_absy/identifier.rs | 4 +- zokrates_core/src/typed_absy/integer.rs | 12 +- zokrates_core/src/typed_absy/mod.rs | 96 +++---- zokrates_core/src/typed_absy/parameter.rs | 6 +- zokrates_core/src/typed_absy/result_folder.rs | 57 +++-- zokrates_core/src/typed_absy/types.rs | 107 ++++---- zokrates_core/src/typed_absy/uint.rs | 6 +- zokrates_core/src/typed_absy/variable.rs | 4 +- 21 files changed, 779 insertions(+), 420 deletions(-) rename zokrates_cli/examples/{compile_errors => }/call_in_constant.zok (84%) create mode 100644 zokrates_cli/examples/complex_call_in_constant.zok diff --git a/zokrates_cli/examples/compile_errors/call_in_constant.zok b/zokrates_cli/examples/call_in_constant.zok similarity index 84% rename from zokrates_cli/examples/compile_errors/call_in_constant.zok rename to zokrates_cli/examples/call_in_constant.zok index 5f18f0d98..b334ed260 100644 --- a/zokrates_cli/examples/compile_errors/call_in_constant.zok +++ b/zokrates_cli/examples/call_in_constant.zok @@ -3,7 +3,7 @@ def yes() -> bool: return true -const TRUE = yes() +const bool TRUE = yes() def main(): return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok b/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok index d1102f745..1f08bc326 100644 --- a/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok +++ b/zokrates_cli/examples/compile_errors/constant_reduction_fail.zok @@ -1,7 +1,6 @@ -def constant() -> u32: - return 3 +from "EMBED" import bit_array_le -const u32 CONSTANT = constant() +const bool CONST = bit_array_le([true], [true]) -def main(field[CONSTANT] a): - return \ No newline at end of file +def main() -> bool: + return CONST \ No newline at end of file diff --git a/zokrates_cli/examples/complex_call_in_constant.zok b/zokrates_cli/examples/complex_call_in_constant.zok new file mode 100644 index 000000000..f988ac237 --- /dev/null +++ b/zokrates_cli/examples/complex_call_in_constant.zok @@ -0,0 +1,14 @@ +def constant() -> u32: + u32 res = 0 + u32 x = 3 + for u32 y in 0..x do + res = res + 1 + endfor + return res + +const u32 CONSTANT = 1 + constant() + +const u32 OTHER_CONSTANT = 42 + +def main(field[CONSTANT] a) -> u32: + return CONSTANT + OTHER_CONSTANT \ No newline at end of file diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index 8a5745c0d..e8325d6d1 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -9,6 +9,16 @@ pub struct Node { pub value: T, } +impl Node { + pub fn mock(e: T) -> Self { + Self { + start: Position::mock(), + end: Position::mock(), + value: e, + } + } +} + impl Node { pub fn pos(&self) -> (Position, Position) { (self.start, self.end) @@ -67,8 +77,7 @@ pub trait NodeValue: fmt::Display + fmt::Debug + Sized + PartialEq { impl From for Node { fn from(v: V) -> Node { - let mock_position = Position { col: 42, line: 42 }; - Node::new(mock_position, mock_position, v) + Node::new(Position::mock(), Position::mock(), v) } } diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 3ee07660d..de10d1e0c 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -1,3 +1,7 @@ +use crate::absy::{ + types::{UnresolvedSignature, UnresolvedType}, + ConstantGenericNode, Expression, +}; use crate::flat_absy::{ FlatDirective, FlatExpression, FlatExpressionList, FlatFunction, FlatParameter, FlatStatement, FlatVariable, RuntimeError, @@ -26,7 +30,7 @@ cfg_if::cfg_if! { /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// the flattening step when it can be inlined. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] pub enum FlatEmbed { BitArrayLe, U32ToField, @@ -46,7 +50,134 @@ pub enum FlatEmbed { } impl FlatEmbed { - pub fn signature(&self) -> DeclarationSignature<'static> { + pub fn signature(&self) -> UnresolvedSignature { + match self { + FlatEmbed::BitArrayLe => UnresolvedSignature::new() + .generics(vec![ConstantGenericNode::mock("N")]) + .inputs(vec![ + UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::Identifier("N").into(), + ) + .into(), + UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::Identifier("N").into(), + ) + .into(), + ]) + .outputs(vec![UnresolvedType::Boolean.into()]), + FlatEmbed::U32ToField => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(32).into()]) + .outputs(vec![UnresolvedType::FieldElement.into()]), + FlatEmbed::Unpack => UnresolvedSignature::new() + .generics(vec!["N".into()]) + .inputs(vec![UnresolvedType::FieldElement.into()]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::Identifier("N").into(), + ) + .into()]), + FlatEmbed::U8ToBits => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(8).into()]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(8).into(), + ) + .into()]), + FlatEmbed::U16ToBits => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(16).into()]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(16).into(), + ) + .into()]), + FlatEmbed::U32ToBits => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(32).into()]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(32).into(), + ) + .into()]), + FlatEmbed::U64ToBits => UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(64).into()]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(64).into(), + ) + .into()]), + FlatEmbed::U8FromBits => UnresolvedSignature::new() + .outputs(vec![UnresolvedType::Uint(8).into()]) + .inputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(8).into(), + ) + .into()]), + FlatEmbed::U16FromBits => UnresolvedSignature::new() + .outputs(vec![UnresolvedType::Uint(16).into()]) + .inputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(16).into(), + ) + .into()]), + FlatEmbed::U32FromBits => UnresolvedSignature::new() + .outputs(vec![UnresolvedType::Uint(32).into()]) + .inputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(32).into(), + ) + .into()]), + FlatEmbed::U64FromBits => UnresolvedSignature::new() + .outputs(vec![UnresolvedType::Uint(64).into()]) + .inputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(64).into(), + ) + .into()]), + #[cfg(feature = "bellman")] + FlatEmbed::Sha256Round => UnresolvedSignature::new() + .inputs(vec![ + UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(512).into(), + ) + .into(), + UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(256).into(), + ) + .into(), + ]) + .outputs(vec![UnresolvedType::array( + UnresolvedType::Boolean.into(), + Expression::U32Constant(256).into(), + ) + .into()]), + #[cfg(feature = "ark")] + FlatEmbed::SnarkVerifyBls12377 => UnresolvedSignature::new() + .generics(vec!["N".into(), "V".into()]) + .inputs(vec![ + UnresolvedType::array( + UnresolvedType::FieldElement.into(), + Expression::Identifier("N").into(), + ) + .into(), // inputs + UnresolvedType::array( + UnresolvedType::FieldElement.into(), + Expression::U32Constant(8).into(), + ) + .into(), // proof + UnresolvedType::array( + UnresolvedType::FieldElement.into(), + Expression::Identifier("V").into(), + ) + .into(), // 18 + (2 * n) // vk + ]) + .outputs(vec![UnresolvedType::Boolean.into()]), + } + } + + pub fn typed_signature(&self) -> DeclarationSignature<'static, T> { match self { FlatEmbed::BitArrayLe => DeclarationSignature::new() .generics(vec![Some(DeclarationConstant::Generic( @@ -181,15 +312,13 @@ impl FlatEmbed { } } - pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec { - let gen = self - .signature() - .generics - .into_iter() - .map(|c| match c.unwrap() { + pub fn generics<'ast, T>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec { + let gen = self.typed_signature().generics.into_iter().map( + |c: Option>| match c.unwrap() { DeclarationConstant::Generic(g) => g, _ => unreachable!(), - }); + }, + ); assert_eq!(gen.len(), assignment.0.len()); gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect() diff --git a/zokrates_core/src/parser/tokenize/position.rs b/zokrates_core/src/parser/tokenize/position.rs index a85ed3ce5..7b1e3b74e 100644 --- a/zokrates_core/src/parser/tokenize/position.rs +++ b/zokrates_core/src/parser/tokenize/position.rs @@ -15,7 +15,6 @@ impl Position { } } - #[cfg(test)] pub fn mock() -> Self { Position { line: 42, col: 42 } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 37ed7510a..9353ce7de 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -10,7 +10,7 @@ use crate::typed_absy::types::GGenericsAssignment; use crate::typed_absy::*; use crate::typed_absy::{DeclarationParameter, DeclarationVariable, Variable}; use num_bigint::BigUint; -use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet}; +use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; use zokrates_field::Field; @@ -55,9 +55,9 @@ impl ErrorInner { } } -type TypeMap<'ast> = HashMap>>; -type ConstantMap<'ast> = - HashMap, DeclarationType<'ast>>>; +type TypeMap<'ast, T> = BTreeMap>>; +type ConstantMap<'ast, T> = + BTreeMap, DeclarationType<'ast, T>>>; /// The global state of the program during semantic checks #[derive(Debug)] @@ -67,26 +67,26 @@ struct State<'ast, T> { /// The already checked modules, which we're returning at the end typed_modules: TypedModules<'ast, T>, /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof - types: TypeMap<'ast>, + types: TypeMap<'ast, T>, // The user-defined constants - constants: ConstantMap<'ast>, + constants: ConstantMap<'ast, T>, } /// A symbol for a given name: either a type or a group of functions. Not both! #[derive(PartialEq, Hash, Eq, Debug)] -enum SymbolType<'ast> { +enum SymbolType<'ast, T> { Type, Constant, - Functions(BTreeSet>), + Functions(BTreeSet>), } /// A data structure to keep track of all symbols in a module #[derive(Default)] -struct SymbolUnifier<'ast> { - symbols: HashMap>, +struct SymbolUnifier<'ast, T> { + symbols: BTreeMap>, } -impl<'ast> SymbolUnifier<'ast> { +impl<'ast, T: std::cmp::Ord> SymbolUnifier<'ast, T> { fn insert_type>(&mut self, id: S) -> bool { let e = self.symbols.entry(id.into()); match e { @@ -116,7 +116,7 @@ impl<'ast> SymbolUnifier<'ast> { fn insert_function>( &mut self, id: S, - signature: DeclarationSignature<'ast>, + signature: DeclarationSignature<'ast, T>, ) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { @@ -142,9 +142,9 @@ impl<'ast, T: Field> State<'ast, T> { fn new(modules: Modules<'ast>) -> Self { State { modules, - typed_modules: HashMap::new(), - types: HashMap::new(), - constants: HashMap::new(), + typed_modules: BTreeMap::new(), + types: BTreeMap::new(), + constants: BTreeMap::new(), } } } @@ -225,7 +225,7 @@ impl<'ast, T: Field> FunctionQuery<'ast, T> { } /// match a `FunctionKey` against this `FunctionQuery` - fn match_func(&self, func: &DeclarationFunctionKey<'ast>) -> bool { + fn match_func(&self, func: &DeclarationFunctionKey<'ast, T>) -> bool { self.id == func.id && self.generics_count.map(|count| count == func.signature.generics.len()).unwrap_or(true) // we do not look at the values here, this will be checked when inlining anyway && self.inputs.len() == func.signature.inputs.len() @@ -249,8 +249,8 @@ impl<'ast, T: Field> FunctionQuery<'ast, T> { fn match_funcs( &self, - funcs: &HashSet>, - ) -> Vec> { + funcs: &HashSet>, + ) -> Vec> { funcs .iter() .filter(|func| self.match_func(func)) @@ -261,57 +261,68 @@ impl<'ast, T: Field> FunctionQuery<'ast, T> { /// A scoped variable, so that we can delete all variables of a given scope when exiting it #[derive(Clone, Debug)] -pub struct ScopedVariable<'ast, T> { - id: Variable<'ast, T>, +pub struct ScopedIdentifier<'ast> { + id: CoreIdentifier<'ast>, level: usize, } -impl<'ast, T> ScopedVariable<'ast, T> { +impl<'ast> ScopedIdentifier<'ast> { fn is_constant(&self) -> bool { self.level == 0 } } -/// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent -impl<'ast, T> PartialEq for ScopedVariable<'ast, T> { +/// Identifiers coming from constants or variables are equivalent +impl<'ast> PartialEq for ScopedIdentifier<'ast> { fn eq(&self, other: &Self) -> bool { - self.id.id == other.id.id + let i0 = match &self.id { + CoreIdentifier::Source(id) => id, + CoreIdentifier::Constant(c) => c.id, + _ => unreachable!(), + }; + + let i1 = match &other.id { + CoreIdentifier::Source(id) => id, + CoreIdentifier::Constant(c) => c.id, + _ => unreachable!(), + }; + + i0 == i1 } } -impl<'ast, T> Hash for ScopedVariable<'ast, T> { +/// Identifiers coming from constants or variables hash to the same result +impl<'ast> Hash for ScopedIdentifier<'ast> { fn hash(&self, state: &mut H) { - self.id.id.hash(state); + match &self.id { + CoreIdentifier::Source(id) => id.hash(state), + CoreIdentifier::Constant(c) => c.id.hash(state), + _ => unreachable!(), + } } } -impl<'ast, T> Eq for ScopedVariable<'ast, T> {} +impl<'ast> Eq for ScopedIdentifier<'ast> {} + +type Scope<'ast, T> = HashMap, Type<'ast, T>>; /// Checker checks the semantics of a program, keeping track of functions and variables in scope +#[derive(Default)] pub struct Checker<'ast, T> { - return_types: Option>>, - scope: HashSet>, - functions: HashSet>, + return_types: Option>>, + scope: Scope<'ast, T>, + functions: HashSet>, level: usize, } impl<'ast, T: Field> Checker<'ast, T> { - fn new() -> Self { - Checker { - return_types: None, - scope: HashSet::new(), - functions: HashSet::new(), - level: 0, - } - } - /// Check a `Program` /// /// # Arguments /// /// * `prog` - The `Program` to be checked pub fn check(prog: Program<'ast>) -> Result, Vec> { - Checker::new().check_program(prog) + Checker::default().check_program(prog) } fn check_program( @@ -362,7 +373,7 @@ impl<'ast, T: Field> Checker<'ast, T> { c.value.ty.clone(), module_id, &state, - &HashMap::default(), + &BTreeMap::default(), &mut HashSet::default(), )?; let checked_expr = @@ -406,7 +417,7 @@ impl<'ast, T: Field> Checker<'ast, T> { s: StructDefinitionNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result, Vec> { + ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -415,7 +426,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let mut fields_set = HashSet::new(); let mut generics = vec![]; - let mut generics_map = HashMap::new(); + let mut generics_map = BTreeMap::new(); for (index, g) in s.generics.iter().enumerate() { if state @@ -508,7 +519,7 @@ impl<'ast, T: Field> Checker<'ast, T> { state: &mut State<'ast, T>, functions: &mut TypedFunctionSymbols<'ast, T>, constants: &mut TypedConstantSymbols<'ast, T>, - symbol_unifier: &mut SymbolUnifier<'ast>, + symbol_unifier: &mut SymbolUnifier<'ast, T>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -575,7 +586,10 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedConstantSymbol::Here(c.clone()), )); self.insert_into_scope(Variable::with_id_and_type( - declaration.id, + CoreIdentifier::Constant(CanonicalConstantIdentifier::new( + declaration.id, + module_id.into(), + )), c.get_type(), )); assert!(state @@ -636,7 +650,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let pos = import.pos(); let import = import.value; - match Checker::new().check_module(&import.module_id, state) { + match Checker::default().check_module(&import.module_id, state) { Ok(()) => { // find candidates in the checked module let function_candidates: Vec<_> = state @@ -779,7 +793,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }; } Symbol::Flat(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature()) { + match symbol_unifier.insert_function(declaration.id, funct.typed_signature()) { false => { errors.push( ErrorInner { @@ -797,11 +811,11 @@ impl<'ast, T: Field> Checker<'ast, T> { self.functions.insert( DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) - .signature(funct.signature()), + .signature(funct.typed_signature()), ); functions.insert( DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) - .signature(funct.signature()), + .signature(funct.typed_signature()), TypedFunctionSymbol::Flat(funct), ); } @@ -839,6 +853,7 @@ impl<'ast, T: Field> Checker<'ast, T> { // we go through symbol declarations and check them for declaration in module.symbols { + println!("{:#?}", self.scope); self.check_symbol_declaration( declaration, module_id, @@ -1044,13 +1059,13 @@ impl<'ast, T: Field> Checker<'ast, T> { signature: UnresolvedSignature<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - ) -> Result, Vec> { + ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; let mut outputs = vec![]; let mut generics = vec![]; - let mut generics_map = HashMap::new(); + let mut generics_map = BTreeMap::new(); for (index, g) in signature.generics.iter().enumerate() { if state @@ -1135,7 +1150,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -1271,17 +1286,17 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, expr: ExpressionNode<'ast>, module_id: &ModuleId, - constants_map: &HashMap, DeclarationType<'ast>>, - generics_map: &HashMap, usize>, + constants_map: &BTreeMap, DeclarationType<'ast, T>>, + generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, - ) -> Result, ErrorInner> { + ) -> Result, ErrorInner> { let pos = expr.pos(); match expr.value { - Expression::U32Constant(c) => Ok(DeclarationConstant::Concrete(c)), + Expression::U32Constant(c) => Ok(DeclarationConstant::from(c)), Expression::IntConstant(c) => { if c <= BigUint::from(2u128.pow(32) - 1) { - Ok(DeclarationConstant::Concrete( + Ok(DeclarationConstant::from( u32::from_str_radix(&c.to_str_radix(16), 16).unwrap(), )) } else { @@ -1332,9 +1347,9 @@ impl<'ast, T: Field> Checker<'ast, T> { ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, state: &State<'ast, T>, - generics_map: &HashMap, usize>, + generics_map: &BTreeMap, usize>, used_generics: &mut HashSet>, - ) -> Result, ErrorInner> { + ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -1346,7 +1361,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let checked_size = self.check_generic_expression( size.clone(), module_id, - state.constants.get(module_id).unwrap_or(&HashMap::new()), + state.constants.get(module_id).unwrap_or(&BTreeMap::new()), generics_map, used_generics, )?; @@ -1383,7 +1398,7 @@ impl<'ast, T: Field> Checker<'ast, T> { state .constants .get(module_id) - .unwrap_or(&HashMap::new()), + .unwrap_or(&BTreeMap::new()), generics_map, used_generics, ) @@ -1450,7 +1465,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, v: crate::absy::VariableNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, Vec> { Ok(Variable::with_id_and_type( v.value.id, @@ -1466,7 +1481,7 @@ impl<'ast, T: Field> Checker<'ast, T> { statements: Vec>, pos: (Position, Position), module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, Vec> { self.check_for_var(&var).map_err(|e| vec![e])?; @@ -1555,7 +1570,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, stat: StatementNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, Vec> { let pos = stat.pos(); @@ -1816,20 +1831,20 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, assignee: AssigneeNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { let pos = assignee.pos(); // check that the assignee is declared match assignee.value { - Assignee::Identifier(variable_name) => match self.get_scope(&variable_name) { - Some(var) => match var.is_constant() { + Assignee::Identifier(variable_name) => match self.get_key_value_scope(&variable_name) { + Some((id, ty)) => match id.is_constant() { true => Err(ErrorInner { pos: Some(assignee.pos()), message: format!("Assignment to constant variable `{}`", variable_name), }), false => Ok(TypedAssignee::Identifier(Variable::with_id_and_type( variable_name, - var.id._type.clone(), + ty.clone(), ))), }, None => Err(ErrorInner { @@ -1917,7 +1932,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, spread_or_expression: SpreadOrExpression<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { @@ -1947,7 +1962,7 @@ impl<'ast, T: Field> Checker<'ast, T> { &mut self, expr: ExpressionNode<'ast>, module_id: &ModuleId, - types: &TypeMap<'ast>, + types: &TypeMap<'ast, T>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1956,23 +1971,28 @@ impl<'ast, T: Field> Checker<'ast, T> { Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope - match self.get_scope(&name) { - Some(v) => match v.id.get_type() { - Type::Boolean => Ok(BooleanExpression::Identifier(name.into()).into()), - Type::Uint(bitwidth) => Ok(UExpressionInner::Identifier(name.into()) + match self + .get_key_value_scope(&name) + .map(|(x, y)| (x.clone(), y.clone())) + { + Some((id, ty)) => match ty { + Type::Boolean => Ok(BooleanExpression::Identifier(id.id.into()).into()), + Type::Uint(bitwidth) => Ok(UExpressionInner::Identifier(id.id.into()) .annotate(bitwidth) .into()), Type::FieldElement => { - Ok(FieldElementExpression::Identifier(name.into()).into()) + Ok(FieldElementExpression::Identifier(id.id.into()).into()) } Type::Array(array_type) => { - Ok(ArrayExpressionInner::Identifier(name.into()) + Ok(ArrayExpressionInner::Identifier(id.id.into()) .annotate(*array_type.ty, array_type.size) .into()) } - Type::Struct(members) => Ok(StructExpressionInner::Identifier(name.into()) - .annotate(members) - .into()), + Type::Struct(members) => { + Ok(StructExpressionInner::Identifier(id.id.into()) + .annotate(members) + .into()) + } Type::Int => unreachable!(), }, None => Err(ErrorInner { @@ -2947,7 +2967,7 @@ impl<'ast, T: Field> Checker<'ast, T> { .clone() .into_iter() .map(|(id, v)| (id.to_string(), v)) - .collect::>(); + .collect::>(); let inferred_values = declared_struct_type .iter() @@ -3247,26 +3267,32 @@ impl<'ast, T: Field> Checker<'ast, T> { } } - fn get_scope<'a>(&'a self, variable_name: &'ast str) -> Option<&'a ScopedVariable<'ast, T>> { - // we take advantage of the fact that all ScopedVariable of the same identifier hash to the same thing - // and by construction only one can be in the set - self.scope.get(&ScopedVariable { - id: Variable::with_id_and_type( - crate::typed_absy::Identifier::from(variable_name), - Type::FieldElement, - ), + fn get_key_value_scope<'a>( + &'a self, + identifier: &'ast str, + ) -> Option<(&'a ScopedIdentifier<'ast>, &'a Type<'ast, T>)> { + self.scope.get_key_value(&ScopedIdentifier { + id: CoreIdentifier::Source(identifier), level: 0, }) } fn insert_into_scope(&mut self, v: Variable<'ast, T>) -> bool { - self.scope.insert(ScopedVariable { - id: v, - level: self.level, - }) + self.scope + .insert( + ScopedIdentifier { + id: v.id.id, + level: self.level, + }, + v._type, + ) + .is_none() } - fn find_functions(&self, query: &FunctionQuery<'ast, T>) -> Vec> { + fn find_functions( + &self, + query: &FunctionQuery<'ast, T>, + ) -> Vec> { query.match_funcs(&self.functions) } @@ -3277,7 +3303,7 @@ impl<'ast, T: Field> Checker<'ast, T> { fn exit_scope(&mut self) { let current_level = self.level; self.scope - .retain(|ref scoped_variable| scoped_variable.level < current_level); + .retain(|ref scoped_variable, _| scoped_variable.level < current_level); self.level -= 1; } } @@ -3998,9 +4024,9 @@ mod tests { } pub fn new_with_args<'ast, T: Field>( - scope: HashSet>, + scope: Scope<'ast, T>, level: usize, - functions: HashSet>, + functions: HashSet>, ) -> Checker<'ast, T> { Checker { scope, @@ -4112,15 +4138,21 @@ mod tests { ) .mock(); - let mut scope = HashSet::new(); - scope.insert(ScopedVariable { - id: Variable::field_element("a"), - level: 1, - }); - scope.insert(ScopedVariable { - id: Variable::field_element("b"), - level: 1, - }); + let mut scope = Scope::default(); + scope.insert( + ScopedIdentifier { + id: CoreIdentifier::Source("a"), + level: 1, + }, + Type::FieldElement, + ); + scope.insert( + ScopedIdentifier { + id: CoreIdentifier::Source("b"), + level: 1, + }, + Type::FieldElement, + ); let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index c318c9c91..2ba585f49 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,9 +1,7 @@ use crate::static_analysis::Propagator; use crate::typed_absy::result_folder::*; -use crate::typed_absy::types::DeclarationConstant; use crate::typed_absy::*; use std::collections::HashMap; -use std::convert::TryInto; use std::fmt; use zokrates_field::Field; @@ -71,20 +69,6 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { .and_then(|constants| constants.get(&id.id)) .cloned() } - - fn get_constant_for_identifier(&self, id: &Identifier<'ast>) -> Option> { - match &id.id { - // canonical constants can be accessed directly in the constant map - CoreIdentifier::Constant(c) => self.get_constant(c), - // source ids are checked against the canonical constant map, setting the module to the current module - CoreIdentifier::Source(id) => self - .constants - .get(&self.location) - .and_then(|constants| constants.get(id)) - .cloned(), - _ => unreachable!(), - } - } } impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { @@ -178,114 +162,111 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { }) } - fn fold_declaration_constant( - &mut self, - c: DeclarationConstant<'ast>, - ) -> Result, Self::Error> { - match c { - // replace constants by their concrete value in declaration types - DeclarationConstant::Constant(id) => { - let id = CanonicalConstantIdentifier { - module: self.fold_module_id(id.module)?, - ..id - }; - - match self.get_constant(&id).unwrap() { - TypedConstant { - expression: TypedExpression::Uint(UExpression { - inner: UExpressionInner::Value(v), - .. - }), - ty: DeclarationType::Uint(UBitwidth::B32) - } => Ok(DeclarationConstant::Concrete(v as u32)), - c => Err(Error::Propagation(format!("Failed to reduce `{}` to a single u32 literal, try avoiding function calls in the definition of `{}` in {}", c, id.id, id.module.display()))) - } - } - c => Ok(c), - } - } - - fn fold_field_expression( - &mut self, - e: FieldElementExpression<'ast, T>, - ) -> Result, Self::Error> { - match e { - FieldElementExpression::Identifier(ref id) => { - match self.get_constant_for_identifier(id) { - Some(c) => Ok(c.try_into().unwrap()), - None => fold_field_expression(self, e), - } - } - e => fold_field_expression(self, e), - } - } - - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> Result, Self::Error> { - match e { - BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) { - Some(c) => Ok(c.try_into().unwrap()), - None => fold_boolean_expression(self, e), - }, - e => fold_boolean_expression(self, e), - } - } - - fn fold_uint_expression_inner( - &mut self, - size: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { - Some(c) => { - let e: UExpression<'ast, T> = c.try_into().unwrap(); - Ok(e.into_inner()) - } - None => fold_uint_expression_inner(self, size, e), - }, - e => fold_uint_expression_inner(self, size, e), - } - } - - fn fold_array_expression_inner( - &mut self, - ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - ArrayExpressionInner::Identifier(ref id) => { - match self.get_constant_for_identifier(id) { - Some(c) => { - let e: ArrayExpression<'ast, T> = c.try_into().unwrap(); - Ok(e.into_inner()) - } - None => fold_array_expression_inner(self, ty, e), - } - } - e => fold_array_expression_inner(self, ty, e), - } - } - - fn fold_struct_expression_inner( - &mut self, - ty: &StructType<'ast, T>, - e: StructExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) - { - Some(c) => { - let e: StructExpression<'ast, T> = c.try_into().unwrap(); - Ok(e.into_inner()) - } - None => fold_struct_expression_inner(self, ty, e), - }, - e => fold_struct_expression_inner(self, ty, e), - } - } + // fn fold_declaration_constant( + // &mut self, + // c: DeclarationConstant<'ast, T>, + // ) -> Result, Self::Error> { + // match c { + // // replace constants by their concrete value in declaration types + // DeclarationConstant::Constant(id) => { + // let id = CanonicalConstantIdentifier { + // module: self.fold_module_id(id.module)?, + // ..id + // }; + + // match self.get_constant(&id).unwrap() { + // TypedConstant { + // ty: DeclarationType::Uint(UBitwidth::B32), + // expression + // } => Ok(DeclarationConstant::Expression(expression)), + // c => Err(Error::Propagation(format!("Failed to reduce `{}` to a single u32 literal, try avoiding function calls in the definition of `{}` in {}", c, id.id, id.module.display()))) + // } + // } + // c => Ok(c), + // } + // } + + // fn fold_field_expression( + // &mut self, + // e: FieldElementExpression<'ast, T>, + // ) -> Result, Self::Error> { + // match e { + // FieldElementExpression::Identifier(ref id) => { + // match self.get_constant_for_identifier(id) { + // Some(c) => Ok(c.try_into().unwrap()), + // None => fold_field_expression(self, e), + // } + // } + // e => fold_field_expression(self, e), + // } + // } + + // fn fold_boolean_expression( + // &mut self, + // e: BooleanExpression<'ast, T>, + // ) -> Result, Self::Error> { + // match e { + // BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) { + // Some(c) => Ok(c.try_into().unwrap()), + // None => fold_boolean_expression(self, e), + // }, + // e => fold_boolean_expression(self, e), + // } + // } + + // fn fold_uint_expression_inner( + // &mut self, + // size: UBitwidth, + // e: UExpressionInner<'ast, T>, + // ) -> Result, Self::Error> { + // match e { + // UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { + // Some(c) => { + // let e: UExpression<'ast, T> = c.try_into().unwrap(); + // Ok(e.into_inner()) + // } + // None => fold_uint_expression_inner(self, size, e), + // }, + // e => fold_uint_expression_inner(self, size, e), + // } + // } + + // fn fold_array_expression_inner( + // &mut self, + // ty: &ArrayType<'ast, T>, + // e: ArrayExpressionInner<'ast, T>, + // ) -> Result, Self::Error> { + // match e { + // ArrayExpressionInner::Identifier(ref id) => { + // match self.get_constant_for_identifier(id) { + // Some(c) => { + // let e: ArrayExpression<'ast, T> = c.try_into().unwrap(); + // Ok(e.into_inner()) + // } + // None => fold_array_expression_inner(self, ty, e), + // } + // } + // e => fold_array_expression_inner(self, ty, e), + // } + // } + + // fn fold_struct_expression_inner( + // &mut self, + // ty: &StructType<'ast, T>, + // e: StructExpressionInner<'ast, T>, + // ) -> Result, Self::Error> { + // match e { + // StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) + // { + // Some(c) => { + // let e: StructExpression<'ast, T> = c.try_into().unwrap(); + // Ok(e.into_inner()) + // } + // None => fold_struct_expression_inner(self, ty, e), + // }, + // e => fold_struct_expression_inner(self, ty, e), + // } + // } } #[cfg(test)] diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index f3973297a..4113c74c6 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -126,7 +126,7 @@ impl<'ast, T: Field> Flattener { fn fold_declaration_parameter( &mut self, - p: typed_absy::DeclarationParameter<'ast>, + p: typed_absy::DeclarationParameter<'ast, T>, ) -> Vec> { let private = p.private; self.fold_variable(crate::typed_absy::variable::try_from_g_variable(p.id).unwrap()) @@ -1093,7 +1093,7 @@ fn fold_function<'ast, T: Field>( statements: main_statements_buffer, signature: typed_absy::types::ConcreteSignature::try_from( crate::typed_absy::types::try_from_g_signature::< - crate::typed_absy::types::DeclarationConstant<'ast>, + crate::typed_absy::types::DeclarationConstant<'ast, T>, crate::typed_absy::UExpression<'ast, T>, >(fun.signature) .unwrap(), diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index b90fc99d6..50eaae27e 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -16,7 +16,7 @@ use std::convert::{TryFrom, TryInto}; use std::fmt; use zokrates_field::Field; -type Constants<'ast, T> = HashMap, TypedExpression<'ast, T>>; +pub type Constants<'ast, T> = HashMap, TypedExpression<'ast, T>>; #[derive(Debug, PartialEq)] pub enum Error { @@ -45,6 +45,7 @@ impl fmt::Display for Error { } } +#[derive(Debug)] pub struct Propagator<'ast, 'a, T: Field> { // constants keeps track of constant expressions // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is diff --git a/zokrates_core/src/static_analysis/reducer/inline.rs b/zokrates_core/src/static_analysis/reducer/inline.rs index d6917de03..d0814b89e 100644 --- a/zokrates_core/src/static_analysis/reducer/inline.rs +++ b/zokrates_core/src/static_analysis/reducer/inline.rs @@ -41,7 +41,7 @@ use crate::typed_absy::{ use zokrates_field::Field; pub enum InlineError<'ast, T> { - Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>), + Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>), Flat( FlatEmbed, Vec, @@ -49,7 +49,7 @@ pub enum InlineError<'ast, T> { Types<'ast, T>, ), NonConstant( - DeclarationFunctionKey<'ast>, + DeclarationFunctionKey<'ast, T>, Vec>>, Vec>, Types<'ast, T>, @@ -57,9 +57,12 @@ pub enum InlineError<'ast, T> { } fn get_canonical_function<'ast, T: Field>( - function_key: DeclarationFunctionKey<'ast>, + function_key: DeclarationFunctionKey<'ast, T>, program: &TypedProgram<'ast, T>, -) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) { +) -> ( + DeclarationFunctionKey<'ast, T>, + TypedFunctionSymbol<'ast, T>, +) { match program .modules .get(&function_key.module) @@ -80,7 +83,7 @@ type InlineResult<'ast, T> = Result< >; pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( - k: DeclarationFunctionKey<'ast>, + k: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, output: &E::Ty, @@ -155,7 +158,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( TypedFunctionSymbol::Here(f) => Ok(f), TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( e, - e.generics(&assignment), + e.generics::(&assignment), arguments.clone(), output_types, )), diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 84790f29b..cb2ba50e1 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -18,26 +18,173 @@ use self::inline::{inline_call, InlineError}; use crate::typed_absy::result_folder::*; use crate::typed_absy::types::ConcreteGenericsAssignment; use crate::typed_absy::types::GGenericsAssignment; +use crate::typed_absy::CanonicalConstantIdentifier; use crate::typed_absy::Folder; +use crate::typed_absy::UBitwidth; use std::collections::HashMap; use crate::typed_absy::{ - ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression, - TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule, - TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable, + ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, DeclarationConstant, + DeclarationSignature, Expr, FieldElementExpression, FunctionCall, FunctionCallExpression, + FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, TypedConstant, + TypedConstantSymbol, TypedExpression, TypedExpressionList, TypedExpressionListInner, + TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, + UExpressionInner, Variable, }; +use std::convert::{TryFrom, TryInto}; use zokrates_field::Field; use self::shallow_ssa::ShallowTransformer; -use crate::static_analysis::Propagator; +use crate::static_analysis::propagation::{Constants, Propagator}; use std::fmt; const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20); +// A map to register the canonical value of all constants. The values must be literals. +type ConstantDefinitions<'ast, T> = + HashMap, TypedExpression<'ast, T>>; + +// A folder to inline all constant definitions down to a single litteral. Also register them in the state for later use. +struct ConstantCallsInliner<'ast, T> { + constants: ConstantDefinitions<'ast, T>, + program: TypedProgram<'ast, T>, +} + +impl<'ast, T> ConstantCallsInliner<'ast, T> { + fn with_program(program: TypedProgram<'ast, T>) -> Self { + ConstantCallsInliner { + constants: ConstantDefinitions::default(), + program, + } + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { + type Error = Error; + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match dbg!(e) { + FieldElementExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + Ok(self.constants.get(&c).cloned().unwrap().try_into().unwrap()) + } + e => fold_field_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match dbg!(e) { + UExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + Ok( + UExpression::try_from(self.constants.get(&c).cloned().unwrap()) + .unwrap() + .into_inner(), + ) + } + e => fold_uint_expression_inner(self, ty, e), + } + } + + fn fold_declaration_constant( + &mut self, + c: DeclarationConstant<'ast, T>, + ) -> Result, Self::Error> { + match c { + DeclarationConstant::Constant(c) => { + if let UExpressionInner::Value(v) = + UExpression::try_from(self.constants.get(&c).cloned().unwrap()) + .unwrap() + .into_inner() + { + Ok(DeclarationConstant::Concrete(v as u32)) + } else { + unreachable!() + } + } + c => fold_declaration_constant(self, c), + } + } + + fn fold_module( + &mut self, + m: TypedModule<'ast, T>, + ) -> Result, Self::Error> { + Ok(TypedModule { + constants: m + .constants + .into_iter() + .map(|(key, tc)| match tc { + TypedConstantSymbol::Here(c) => { + let wrapper = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![c.expression])], + signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), + }; + + let mut inlined_wrapper = reduce_function( + wrapper, + ConcreteGenericsAssignment::default(), + &self.program, + )?; + + if inlined_wrapper.statements.len() > 1 { + return Err(Error::ConstantReduction(key.id.to_string(), key.module)); + }; + + if let TypedStatement::Return(mut expressions) = + inlined_wrapper.statements.pop().unwrap() + { + assert_eq!(expressions.len(), 1); + let constant_expression = expressions.pop().unwrap(); + use crate::typed_absy::Constant; + if !constant_expression.is_constant() { + return Err(Error::ConstantReduction( + key.id.to_string(), + key.module, + )); + }; + self.constants + .insert(key.clone(), constant_expression.clone()); + Ok(( + key, + TypedConstantSymbol::Here(TypedConstant { + expression: constant_expression, + ty: c.ty, + }), + )) + } else { + return Err(Error::ConstantReduction(key.id.to_string(), key.module)); + } + } + _ => unreachable!("all constants should be local"), + }) + .collect::, _>>()?, + functions: m + .functions + .into_iter() + .map(|(key, fun)| self.fold_function_symbol(fun).map(|f| (key, f))) + .collect::>()?, + }) + } +} + // An SSA version map, giving access to the latest version number for each identifier pub type Versions<'ast> = HashMap, usize>; @@ -55,6 +202,7 @@ pub enum Error { // TODO: give more details about what's blocking the progress NoProgress, LoopTooLarge(u128), + ConstantReduction(String, OwnedTypedModuleId), } impl fmt::Display for Error { @@ -68,6 +216,7 @@ impl fmt::Display for Error { Error::GenericsInMain => write!(f, "Cannot generate code for generic function"), Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"), Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE), + Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()), } } } @@ -159,6 +308,7 @@ fn register<'ast>( } } +#[derive(Debug)] struct Reducer<'ast, 'a, T> { statement_buffer: Vec>, for_loop_versions: Vec>, @@ -304,10 +454,19 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { }) } + fn fold_canonical_constant_identifier( + &mut self, + _: CanonicalConstantIdentifier<'ast>, + ) -> Result, Self::Error> { + unreachable!("canonical constant identifiers should not be folded, they should be inlined") + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { + println!("STAT {}", s); + let res = match s { TypedStatement::MultipleDefinition( v, @@ -487,6 +646,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { } pub fn reduce_program(p: TypedProgram) -> Result, Error> { + // inline all constants and replace them in the program + let mut constant_calls_inliner = ConstantCallsInliner::with_program(p.clone()); + + let p = constant_calls_inliner.fold_program(p)?; + + // inline starting from main let main_module = p.modules.get(&p.main).unwrap().clone(); let (main_key, main_function) = main_module @@ -542,7 +707,7 @@ fn reduce_function<'ast, T: Field>( let mut substitutions = Substitutions::default(); - let mut constants: HashMap, TypedExpression<'ast, T>> = HashMap::new(); + let mut constants = Constants::default(); let mut hash = None; diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index cf2bb958e..fae489b67 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -67,8 +67,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_declaration_function_key( &mut self, - key: DeclarationFunctionKey<'ast>, - ) -> DeclarationFunctionKey<'ast> { + key: DeclarationFunctionKey<'ast, T>, + ) -> DeclarationFunctionKey<'ast, T> { fold_declaration_function_key(self, key) } @@ -76,18 +76,24 @@ pub trait Folder<'ast, T: Field>: Sized { fold_function(self, f) } - fn fold_signature(&mut self, s: DeclarationSignature<'ast>) -> DeclarationSignature<'ast> { + fn fold_signature( + &mut self, + s: DeclarationSignature<'ast, T>, + ) -> DeclarationSignature<'ast, T> { fold_signature(self, s) } fn fold_declaration_constant( &mut self, - c: DeclarationConstant<'ast>, - ) -> DeclarationConstant<'ast> { + c: DeclarationConstant<'ast, T>, + ) -> DeclarationConstant<'ast, T> { fold_declaration_constant(self, c) } - fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> { + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast, T>, + ) -> DeclarationParameter<'ast, T> { DeclarationParameter { id: self.fold_declaration_variable(p.id), ..p @@ -107,8 +113,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_declaration_variable( &mut self, - v: DeclarationVariable<'ast>, - ) -> DeclarationVariable<'ast> { + v: DeclarationVariable<'ast, T>, + ) -> DeclarationVariable<'ast, T> { DeclarationVariable { id: self.fold_name(v.id), _type: self.fold_declaration_type(v._type), @@ -155,7 +161,7 @@ pub trait Folder<'ast, T: Field>: Sized { } } - fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> { + fn fold_declaration_type(&mut self, t: DeclarationType<'ast, T>) -> DeclarationType<'ast, T> { use self::GType::*; match t { @@ -167,8 +173,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_declaration_array_type( &mut self, - t: DeclarationArrayType<'ast>, - ) -> DeclarationArrayType<'ast> { + t: DeclarationArrayType<'ast, T>, + ) -> DeclarationArrayType<'ast, T> { DeclarationArrayType { ty: box self.fold_declaration_type(*t.ty), size: self.fold_declaration_constant(t.size), @@ -177,8 +183,8 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_declaration_struct_type( &mut self, - t: DeclarationStructType<'ast>, - ) -> DeclarationStructType<'ast> { + t: DeclarationStructType<'ast, T>, + ) -> DeclarationStructType<'ast, T> { DeclarationStructType { generics: t .generics @@ -901,8 +907,8 @@ pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T pub fn fold_declaration_function_key<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - key: DeclarationFunctionKey<'ast>, -) -> DeclarationFunctionKey<'ast> { + key: DeclarationFunctionKey<'ast, T>, +) -> DeclarationFunctionKey<'ast, T> { DeclarationFunctionKey { module: f.fold_module_id(key.module), signature: f.fold_signature(key.signature), @@ -954,8 +960,8 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - s: DeclarationSignature<'ast>, -) -> DeclarationSignature<'ast> { + s: DeclarationSignature<'ast, T>, +) -> DeclarationSignature<'ast, T> { DeclarationSignature { generics: s.generics, inputs: s @@ -972,10 +978,13 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( } fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>( - _: &mut F, - c: DeclarationConstant<'ast>, -) -> DeclarationConstant<'ast> { - c + f: &mut F, + c: DeclarationConstant<'ast, T>, +) -> DeclarationConstant<'ast, T> { + match c { + DeclarationConstant::Expression(e) => DeclarationConstant::Expression(f.fold_expression(e)), + c => c, + } } pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index 6b2420749..d60226713 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -2,7 +2,7 @@ use crate::typed_absy::CanonicalConstantIdentifier; use std::convert::TryInto; use std::fmt; -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] pub enum CoreIdentifier<'ast> { Source(&'ast str), Call(usize), @@ -32,7 +32,7 @@ impl<'ast> From> for CoreIdentifier<'ast> { } /// A identifier for a variable -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord)] pub struct Identifier<'ast> { /// the id of the variable pub id: CoreIdentifier<'ast>, diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 62eba07c3..b430f3a92 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -40,7 +40,7 @@ trait IntegerInference: Sized { } impl<'ast, T> IntegerInference for Type<'ast, T> { - type Pattern = DeclarationType<'ast>; + type Pattern = DeclarationType<'ast, T>; fn get_common_pattern(self, other: Self) -> Result { match (self, other) { @@ -72,7 +72,7 @@ impl<'ast, T> IntegerInference for Type<'ast, T> { } impl<'ast, T> IntegerInference for ArrayType<'ast, T> { - type Pattern = DeclarationArrayType<'ast>; + type Pattern = DeclarationArrayType<'ast, T>; fn get_common_pattern(self, other: Self) -> Result { let s0 = self.size; @@ -88,7 +88,7 @@ impl<'ast, T> IntegerInference for ArrayType<'ast, T> { } impl<'ast, T> IntegerInference for StructType<'ast, T> { - type Pattern = DeclarationStructType<'ast>; + type Pattern = DeclarationStructType<'ast, T>; fn get_common_pattern(self, other: Self) -> Result { Ok(DeclarationStructType { @@ -228,7 +228,7 @@ impl<'ast, T: Field> TypedExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] pub enum IntExpression<'ast, T> { Value(BigUint), Pos(Box>), @@ -424,7 +424,7 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { v, &DeclarationArrayType::new( DeclarationType::FieldElement, - DeclarationConstant::Concrete(0), + DeclarationConstant::from(0u32), ), ) .map_err(|(e, _)| match e { @@ -542,7 +542,7 @@ impl<'ast, T: Field> UExpression<'ast, T> { v, &DeclarationArrayType::new( DeclarationType::Uint(*bitwidth), - DeclarationConstant::Concrete(0), + DeclarationConstant::from(0u32), ), ) .map_err(|(e, _)| match e { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 34ef85904..e2f58456a 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -20,9 +20,9 @@ pub use self::identifier::CoreIdentifier; pub use self::parameter::{DeclarationParameter, GParameter}; pub use self::types::{ CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType, - ConstantIdentifier, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature, - DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier, - IntoTypes, Signature, StructType, Type, Types, UBitwidth, + ConstantIdentifier, DeclarationArrayType, DeclarationConstant, DeclarationFunctionKey, + DeclarationSignature, DeclarationStructType, DeclarationType, GArrayType, GStructType, GType, + GenericIdentifier, IntoTypes, Signature, StructType, Type, Types, UBitwidth, }; use crate::typed_absy::types::ConcreteGenericsAssignment; @@ -35,7 +35,7 @@ pub use crate::typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMeta use crate::embed::FlatEmbed; -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; use std::fmt; @@ -54,14 +54,14 @@ pub type OwnedTypedModuleId = PathBuf; pub type TypedModuleId = Path; /// A collection of `TypedModule`s -pub type TypedModules<'ast, T> = HashMap>; +pub type TypedModules<'ast, T> = BTreeMap>; /// A collection of `TypedFunctionSymbol`s /// # Remarks /// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey` -/// in a given `TypedModule`, hence the use of a HashMap +/// in a given `TypedModule`, hence the use of a BTreeMap pub type TypedFunctionSymbols<'ast, T> = - HashMap, TypedFunctionSymbol<'ast, T>>; + BTreeMap, TypedFunctionSymbol<'ast, T>>; #[derive(Clone, Debug, PartialEq)] pub enum TypedConstantSymbol<'ast, T> { @@ -109,7 +109,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { .map(|p| { types::ConcreteType::try_from( crate::typed_absy::types::try_from_g_type::< - crate::typed_absy::types::DeclarationConstant<'ast>, + DeclarationConstant<'ast, T>, UExpression<'ast, T>, >(p.id._type.clone()) .unwrap(), @@ -129,7 +129,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { .map(|ty| { types::ConcreteType::try_from( crate::typed_absy::types::try_from_g_type::< - crate::typed_absy::types::DeclarationConstant<'ast>, + DeclarationConstant<'ast, T>, UExpression<'ast, T>, >(ty.clone()) .unwrap(), @@ -175,7 +175,7 @@ pub struct TypedModule<'ast, T> { #[derive(Clone, PartialEq, Debug)] pub enum TypedFunctionSymbol<'ast, T> { Here(TypedFunction<'ast, T>), - There(DeclarationFunctionKey<'ast>), + There(DeclarationFunctionKey<'ast, T>), Flat(FlatEmbed), } @@ -183,7 +183,7 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { pub fn signature<'a>( &'a self, modules: &'a TypedModules<'ast, T>, - ) -> DeclarationSignature<'ast> { + ) -> DeclarationSignature<'ast, T> { match self { TypedFunctionSymbol::Here(f) => f.signature.clone(), TypedFunctionSymbol::There(key) => modules @@ -193,7 +193,7 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { .get(key) .unwrap() .signature(&modules), - TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature(), + TypedFunctionSymbol::Flat(flat_fun) => flat_fun.typed_signature(), } } } @@ -226,7 +226,11 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { key.signature ), TypedFunctionSymbol::Flat(ref flat_fun) => { - format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature()) + format!( + "def {}{}:\n\t// hidden", + key.id, + flat_fun.typed_signature::() + ) } })) .collect::>(); @@ -239,11 +243,11 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { #[derive(Clone, PartialEq, Debug, Hash)] pub struct TypedFunction<'ast, T> { /// Arguments of the function - pub arguments: Vec>, + pub arguments: Vec>, /// Vector of statements that are executed when running the function pub statements: Vec>, /// function signature - pub signature: DeclarationSignature<'ast>, + pub signature: DeclarationSignature<'ast, T>, } impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { @@ -312,11 +316,11 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { #[derive(Clone, PartialEq, Debug)] pub struct TypedConstant<'ast, T> { pub expression: TypedExpression<'ast, T>, - pub ty: DeclarationType<'ast>, + pub ty: DeclarationType<'ast, T>, } impl<'ast, T> TypedConstant<'ast, T> { - pub fn new(expression: TypedExpression<'ast, T>, ty: DeclarationType<'ast>) -> Self { + pub fn new(expression: TypedExpression<'ast, T>, ty: DeclarationType<'ast, T>) -> Self { TypedConstant { expression, ty } } } @@ -334,14 +338,14 @@ impl<'ast, T: Field> Typed<'ast, T> for TypedConstant<'ast, T> { } /// Something we can assign to. -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedAssignee<'ast, T> { Identifier(Variable<'ast, T>), Select(Box>, Box>), Member(Box>, MemberId), } -#[derive(Clone, PartialEq, Hash, Eq, Debug)] +#[derive(Clone, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)] pub struct TypedSpread<'ast, T> { pub array: ArrayExpression<'ast, T>, } @@ -358,7 +362,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedSpread<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedExpressionOrSpread<'ast, T> { Expression(TypedExpression<'ast, T>), Spread(TypedSpread<'ast, T>), @@ -488,7 +492,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> { /// A statement in a `TypedFunction` #[allow(clippy::large_enum_variant)] -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedStatement<'ast, T> { Return(Vec>), Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), @@ -503,7 +507,7 @@ pub enum TypedStatement<'ast, T> { MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), // Aux PushCallLog( - DeclarationFunctionKey<'ast>, + DeclarationFunctionKey<'ast, T>, ConcreteGenericsAssignment<'ast>, ), PopCallLog, @@ -576,7 +580,7 @@ pub trait Typed<'ast, T> { /// A typed expression #[allow(clippy::large_enum_variant)] -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedExpression<'ast, T> { Boolean(BooleanExpression<'ast, T>), FieldElement(FieldElementExpression<'ast, T>), @@ -715,7 +719,7 @@ pub trait MultiTyped<'ast, T> { fn get_types(&self) -> &Vec>; } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct TypedExpressionList<'ast, T> { pub inner: TypedExpressionListInner<'ast, T>, @@ -728,7 +732,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum TypedExpressionListInner<'ast, T> { FunctionCall(FunctionCallExpression<'ast, T, TypedExpressionList<'ast, T>>), EmbedCall(FlatEmbed, Vec, Vec>), @@ -745,7 +749,7 @@ impl<'ast, T> TypedExpressionListInner<'ast, T> { TypedExpressionList { inner: self, types } } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct BlockExpression<'ast, T, E> { pub statements: Vec>, pub value: Box, @@ -760,7 +764,7 @@ impl<'ast, T, E> BlockExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct MemberExpression<'ast, T, E> { pub struc: Box>, pub id: MemberId, @@ -783,7 +787,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for MemberExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct SelectExpression<'ast, T, E> { pub array: Box>, pub index: Box>, @@ -806,7 +810,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct IfElseExpression<'ast, T, E> { pub condition: Box>, pub consequence: Box, @@ -833,9 +837,9 @@ impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<' } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct FunctionCallExpression<'ast, T, E> { - pub function_key: DeclarationFunctionKey<'ast>, + pub function_key: DeclarationFunctionKey<'ast, T>, pub generics: Vec>>, pub arguments: Vec>, ty: PhantomData, @@ -843,7 +847,7 @@ pub struct FunctionCallExpression<'ast, T, E> { impl<'ast, T, E> FunctionCallExpression<'ast, T, E> { pub fn new( - function_key: DeclarationFunctionKey<'ast>, + function_key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self { @@ -886,7 +890,7 @@ impl<'ast, T: fmt::Display, E> fmt::Display for FunctionCallExpression<'ast, T, } /// An expression of type `field` -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum FieldElementExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), Number(T), @@ -963,7 +967,7 @@ impl<'ast, T> From for FieldElementExpression<'ast, T> { } /// An expression of type `bool` -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum BooleanExpression<'ast, T> { Block(BlockExpression<'ast, T, Self>), Identifier(Identifier<'ast>), @@ -1028,13 +1032,13 @@ impl<'ast, T> From for BooleanExpression<'ast, T> { /// * Contrary to basic types which are represented as enums, we wrap an enum `ArrayExpressionInner` in a struct in order to keep track of the type (content and size) /// of the array. Only using an enum would require generics, which would propagate up to TypedExpression which we want to keep simple, hence this "runtime" /// type checking -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct ArrayExpression<'ast, T> { pub ty: Box>, pub inner: ArrayExpressionInner<'ast, T>, } -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] pub struct ArrayValue<'ast, T>(pub Vec>); impl<'ast, T> From>> for ArrayValue<'ast, T> { @@ -1112,7 +1116,7 @@ impl<'ast, T> std::iter::FromIterator> for Arra } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum ArrayExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>), Identifier(Identifier<'ast>), @@ -1152,7 +1156,7 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub struct StructExpression<'ast, T> { ty: StructType<'ast, T>, inner: StructExpressionInner<'ast, T>, @@ -1176,7 +1180,7 @@ impl<'ast, T> StructExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Debug, Hash, Eq)] +#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)] pub enum StructExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, StructExpression<'ast, T>>), Identifier(Identifier<'ast>), @@ -1898,7 +1902,7 @@ impl<'ast, T: Field> Id<'ast, T> for TypedExpressionList<'ast, T> { pub trait FunctionCall<'ast, T>: Expr<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner; @@ -1906,7 +1910,7 @@ pub trait FunctionCall<'ast, T>: Expr<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { @@ -1916,7 +1920,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { @@ -1926,7 +1930,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { @@ -1936,7 +1940,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { @@ -1946,7 +1950,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { @@ -1956,7 +1960,7 @@ impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> { impl<'ast, T: Field> FunctionCall<'ast, T> for TypedExpressionList<'ast, T> { fn function_call( - key: DeclarationFunctionKey<'ast>, + key: DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, ) -> Self::Inner { diff --git a/zokrates_core/src/typed_absy/parameter.rs b/zokrates_core/src/typed_absy/parameter.rs index 454219455..34dbb5b0c 100644 --- a/zokrates_core/src/typed_absy/parameter.rs +++ b/zokrates_core/src/typed_absy/parameter.rs @@ -18,12 +18,12 @@ impl<'ast, S> From> for GParameter<'ast, S> { } } -pub type DeclarationParameter<'ast> = GParameter<'ast, DeclarationConstant<'ast>>; +pub type DeclarationParameter<'ast, T> = GParameter<'ast, DeclarationConstant<'ast, T>>; -impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> { +impl<'ast, S: fmt::Display> fmt::Display for GParameter<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let visibility = if self.private { "private " } else { "" }; - write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id) + write!(f, "{}{} {}", visibility, self.id._type, self.id.id) } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index f19062446..218f53867 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -78,8 +78,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_declaration_function_key( &mut self, - key: DeclarationFunctionKey<'ast>, - ) -> Result, Self::Error> { + key: DeclarationFunctionKey<'ast, T>, + ) -> Result, Self::Error> { fold_declaration_function_key(self, key) } @@ -92,22 +92,22 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_signature( &mut self, - s: DeclarationSignature<'ast>, - ) -> Result, Self::Error> { + s: DeclarationSignature<'ast, T>, + ) -> Result, Self::Error> { fold_signature(self, s) } fn fold_declaration_constant( &mut self, - c: DeclarationConstant<'ast>, - ) -> Result, Self::Error> { + c: DeclarationConstant<'ast, T>, + ) -> Result, Self::Error> { fold_declaration_constant(self, c) } fn fold_parameter( &mut self, - p: DeclarationParameter<'ast>, - ) -> Result, Self::Error> { + p: DeclarationParameter<'ast, T>, + ) -> Result, Self::Error> { Ok(DeclarationParameter { id: self.fold_declaration_variable(p.id)?, ..p @@ -141,8 +141,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_declaration_variable( &mut self, - v: DeclarationVariable<'ast>, - ) -> Result, Self::Error> { + v: DeclarationVariable<'ast, T>, + ) -> Result, Self::Error> { Ok(DeclarationVariable { id: self.fold_name(v.id)?, _type: self.fold_declaration_type(v._type)?, @@ -245,8 +245,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_declaration_type( &mut self, - t: DeclarationType<'ast>, - ) -> Result, Self::Error> { + t: DeclarationType<'ast, T>, + ) -> Result, Self::Error> { use self::GType::*; match t { @@ -258,8 +258,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_declaration_array_type( &mut self, - t: DeclarationArrayType<'ast>, - ) -> Result, Self::Error> { + t: DeclarationArrayType<'ast, T>, + ) -> Result, Self::Error> { Ok(DeclarationArrayType { ty: box self.fold_declaration_type(*t.ty)?, size: self.fold_declaration_constant(t.size)?, @@ -268,8 +268,8 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fn fold_declaration_struct_type( &mut self, - t: DeclarationStructType<'ast>, - ) -> Result, Self::Error> { + t: DeclarationStructType<'ast, T>, + ) -> Result, Self::Error> { Ok(DeclarationStructType { generics: t .generics @@ -971,8 +971,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, - key: DeclarationFunctionKey<'ast>, -) -> Result, F::Error> { + key: DeclarationFunctionKey<'ast, T>, +) -> Result, F::Error> { Ok(DeclarationFunctionKey { module: f.fold_module_id(key.module)?, signature: f.fold_signature(key.signature)?, @@ -1002,10 +1002,10 @@ pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( +pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, - s: DeclarationSignature<'ast>, -) -> Result, F::Error> { + s: DeclarationSignature<'ast, T>, +) -> Result, F::Error> { Ok(DeclarationSignature { generics: s.generics, inputs: s @@ -1021,11 +1021,16 @@ fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( }) } -fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( - _: &mut F, - c: DeclarationConstant<'ast>, -) -> Result, F::Error> { - Ok(c) +pub fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + c: DeclarationConstant<'ast, T>, +) -> Result, F::Error> { + match c { + DeclarationConstant::Expression(e) => { + Ok(DeclarationConstant::Expression(f.fold_expression(e)?)) + } + c => Ok(c), + } } pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 9d8268d57..7520b4b79 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,5 +1,5 @@ use crate::typed_absy::{ - CoreIdentifier, Identifier, OwnedTypedModuleId, UExpression, UExpressionInner, + CoreIdentifier, Identifier, OwnedTypedModuleId, TypedExpression, UExpression, UExpressionInner, }; use crate::typed_absy::{TryFrom, TryInto}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; @@ -48,7 +48,7 @@ impl<'ast, T> IntoTypes<'ast, T> for Types<'ast, T> { } } -#[derive(Debug, Clone, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct Types<'ast, T> { pub inner: Vec>, } @@ -118,51 +118,62 @@ impl<'ast> CanonicalConstantIdentifier<'ast> { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum DeclarationConstant<'ast> { +pub enum DeclarationConstant<'ast, T> { Generic(GenericIdentifier<'ast>), Concrete(u32), Constant(CanonicalConstantIdentifier<'ast>), + Expression(TypedExpression<'ast, T>), } -impl<'ast, T> PartialEq> for DeclarationConstant<'ast> { +impl<'ast, T: PartialEq> PartialEq> for DeclarationConstant<'ast, T> { fn eq(&self, other: &UExpression<'ast, T>) -> bool { - match (self, other.as_inner()) { - (DeclarationConstant::Concrete(c), UExpressionInner::Value(v)) => *c == *v as u32, + match (self, other) { + ( + DeclarationConstant::Concrete(c), + UExpression { + bitwidth: UBitwidth::B32, + inner: UExpressionInner::Value(v), + .. + }, + ) => *c == *v as u32, + (DeclarationConstant::Expression(TypedExpression::Uint(e0)), e1) => e0 == e1, + (DeclarationConstant::Expression(..), _) => false, // type error _ => true, } } } -impl<'ast, T> PartialEq> for UExpression<'ast, T> { - fn eq(&self, other: &DeclarationConstant<'ast>) -> bool { +impl<'ast, T: PartialEq> PartialEq> for UExpression<'ast, T> { + fn eq(&self, other: &DeclarationConstant<'ast, T>) -> bool { other.eq(self) } } -impl<'ast> From for DeclarationConstant<'ast> { +impl<'ast, T> From for DeclarationConstant<'ast, T> { fn from(e: u32) -> Self { DeclarationConstant::Concrete(e) } } -impl<'ast> From for DeclarationConstant<'ast> { +impl<'ast, T> From for DeclarationConstant<'ast, T> { fn from(e: usize) -> Self { DeclarationConstant::Concrete(e as u32) } } -impl<'ast> From> for DeclarationConstant<'ast> { +impl<'ast, T> From> for DeclarationConstant<'ast, T> { fn from(e: GenericIdentifier<'ast>) -> Self { DeclarationConstant::Generic(e) } } -impl<'ast> fmt::Display for DeclarationConstant<'ast> { +impl<'ast, T: fmt::Display> fmt::Display for DeclarationConstant<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { DeclarationConstant::Generic(i) => write!(f, "{}", i), DeclarationConstant::Concrete(v) => write!(f, "{}", v), DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id), + DeclarationConstant::Expression(e) => write!(f, "{}", e), } } } @@ -173,8 +184,8 @@ impl<'ast, T> From for UExpression<'ast, T> { } } -impl<'ast, T> From> for UExpression<'ast, T> { - fn from(c: DeclarationConstant<'ast>) -> Self { +impl<'ast, T> From> for UExpression<'ast, T> { + fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { DeclarationConstant::Generic(i) => { UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32) @@ -185,6 +196,7 @@ impl<'ast, T> From> for UExpression<'ast, T> { DeclarationConstant::Constant(v) => { UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32) } + DeclarationConstant::Expression(e) => e.try_into().unwrap(), } } } @@ -202,7 +214,7 @@ impl<'ast, T> TryInto for UExpression<'ast, T> { } } -impl<'ast> TryInto for DeclarationConstant<'ast> { +impl<'ast, T> TryInto for DeclarationConstant<'ast, T> { type Error = SpecializationError; fn try_into(self) -> Result { @@ -223,7 +235,7 @@ pub struct GStructMember { pub ty: Box>, } -pub type DeclarationStructMember<'ast> = GStructMember>; +pub type DeclarationStructMember<'ast, T> = GStructMember>; pub type ConcreteStructMember = GStructMember; pub type StructMember<'ast, T> = GStructMember>; @@ -263,7 +275,7 @@ pub struct GArrayType { pub ty: Box>, } -pub type DeclarationArrayType<'ast> = GArrayType>; +pub type DeclarationArrayType<'ast, T> = GArrayType>; pub type ConcreteArrayType = GArrayType; pub type ArrayType<'ast, T> = GArrayType>; @@ -329,7 +341,7 @@ pub struct StructLocation { pub name: String, } -impl<'ast> From for DeclarationArrayType<'ast> { +impl<'ast, T> From for DeclarationArrayType<'ast, T> { fn from(t: ConcreteArrayType) -> Self { try_from_g_array_type(t).unwrap() } @@ -345,7 +357,7 @@ pub struct GStructType { pub members: Vec>, } -pub type DeclarationStructType<'ast> = GStructType>; +pub type DeclarationStructType<'ast, T> = GStructType>; pub type ConcreteStructType = GStructType; pub type StructType<'ast, T> = GStructType>; @@ -409,7 +421,7 @@ impl<'ast, T> From for StructType<'ast, T> { } } -impl<'ast> From for DeclarationStructType<'ast> { +impl<'ast, T> From for DeclarationStructType<'ast, T> { fn from(t: ConcreteStructType) -> Self { try_from_g_struct_type(t).unwrap() } @@ -602,7 +614,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType { } } -pub type DeclarationType<'ast> = GType>; +pub type DeclarationType<'ast, T> = GType>; pub type ConcreteType = GType; pub type Type<'ast, T> = GType>; @@ -645,7 +657,7 @@ impl<'ast, T> From for Type<'ast, T> { } } -impl<'ast> From for DeclarationType<'ast> { +impl<'ast, T> From for DeclarationType<'ast, T> { fn from(t: ConcreteType) -> Self { try_from_g_type(t).unwrap() } @@ -731,7 +743,7 @@ impl GType { } impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> { - pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool { + pub fn can_be_specialized_to(&self, other: &DeclarationType<'ast, T>) -> bool { use self::GType::*; if other == self { @@ -804,14 +816,14 @@ impl ConcreteType { pub type FunctionIdentifier<'ast> = &'ast str; -#[derive(PartialEq, Eq, Hash, Debug, Clone)] +#[derive(PartialEq, Eq, Hash, Debug, Clone, PartialOrd, Ord)] pub struct GFunctionKey<'ast, S> { pub module: OwnedTypedModuleId, pub id: FunctionIdentifier<'ast>, pub signature: GSignature, } -pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>; +pub type DeclarationFunctionKey<'ast, T> = GFunctionKey<'ast, DeclarationConstant<'ast, T>>; pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>; pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>; @@ -821,7 +833,7 @@ impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> { } } -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] pub struct GGenericsAssignment<'ast, S>(pub BTreeMap, S>); pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, usize>; @@ -847,8 +859,8 @@ impl<'ast, S: fmt::Display> fmt::Display for GGenericsAssignment<'ast, S> { } } -impl<'ast> PartialEq> for ConcreteFunctionKey<'ast> { - fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool { +impl<'ast, T> PartialEq> for ConcreteFunctionKey<'ast> { + fn eq(&self, other: &DeclarationFunctionKey<'ast, T>) -> bool { self.module == other.module && self.id == other.id && self.signature == other.signature } } @@ -877,7 +889,7 @@ impl<'ast, T> From> for FunctionKey<'ast, T> { } } -impl<'ast> From> for DeclarationFunctionKey<'ast> { +impl<'ast, T> From> for DeclarationFunctionKey<'ast, T> { fn from(k: ConcreteFunctionKey<'ast>) -> Self { try_from_g_function_key(k).unwrap() } @@ -924,8 +936,8 @@ impl<'ast> ConcreteFunctionKey<'ast> { use std::collections::btree_map::Entry; -pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( - decl_ty: &DeclarationType<'ast>, +pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq>( + decl_ty: &DeclarationType<'ast, T>, ty: >ype, constants: &mut GGenericsAssignment<'ast, S>, ) -> bool { @@ -946,9 +958,9 @@ pub fn check_type<'ast, S: Clone + PartialEq + PartialEq>( } }, DeclarationConstant::Concrete(s0) => s1 == *s0 as usize, - // in the case of a constant, we do not know the value yet, so we optimistically assume it's correct + // in the other cases, we do not know the value yet, so we optimistically assume it's correct // if it does not match, it will be caught during inlining - DeclarationConstant::Constant(..) => true, + DeclarationConstant::Constant(..) | DeclarationConstant::Expression(..) => true, } } (DeclarationType::FieldElement, GType::FieldElement) @@ -968,7 +980,7 @@ impl<'ast, T> From> for UExpression<'ast, T> { } } -impl<'ast> From> for DeclarationConstant<'ast> { +impl<'ast, T> From> for DeclarationConstant<'ast, T> { fn from(c: CanonicalConstantIdentifier<'ast>) -> Self { DeclarationConstant::Constant(c) } @@ -976,21 +988,21 @@ impl<'ast> From> for DeclarationConstant<'ast> pub fn specialize_declaration_type< 'ast, + T, S: Clone + PartialEq + From + fmt::Debug + From>, >( - decl_ty: DeclarationType<'ast>, + decl_ty: DeclarationType<'ast, T>, generics: &GGenericsAssignment<'ast, S>, ) -> Result, GenericIdentifier<'ast>> { Ok(match decl_ty { DeclarationType::Int => unreachable!(), DeclarationType::Array(t0) => { - // let s1 = t1.size.clone(); - let ty = box specialize_declaration_type(*t0.ty, &generics)?; let size = match t0.size { DeclarationConstant::Generic(s) => generics.0.get(&s).cloned().ok_or(s), DeclarationConstant::Concrete(s) => Ok(s.into()), DeclarationConstant::Constant(c) => Ok(c.into()), + DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant") }?; GType::Array(GArrayType { size, ty }) @@ -1017,11 +1029,8 @@ pub fn specialize_declaration_type< generics.0.get(&s).cloned().ok_or(s).map(Some) } DeclarationConstant::Concrete(s) => Ok(Some(s.into())), - DeclarationConstant::Constant(..) => { - unreachable!( - "identifiers should have been removed in constant inlining" - ) - } + DeclarationConstant::Constant(c) => Ok(Some(c.into())), + DeclarationConstant::Expression(..) => unreachable!("the semantic checker should not yield this DeclarationConstant variant"), }, _ => Ok(None), }) @@ -1085,12 +1094,12 @@ pub mod signature { } } - pub type DeclarationSignature<'ast> = GSignature>; + pub type DeclarationSignature<'ast, T> = GSignature>; pub type ConcreteSignature = GSignature; pub type Signature<'ast, T> = GSignature>; - impl<'ast> PartialEq> for ConcreteSignature { - fn eq(&self, other: &DeclarationSignature<'ast>) -> bool { + impl<'ast, T> PartialEq> for ConcreteSignature { + fn eq(&self, other: &DeclarationSignature<'ast, T>) -> bool { // we keep track of the value of constants in a map, as a given constant can only have one value let mut constants = ConcreteGenericsAssignment::default(); @@ -1099,11 +1108,11 @@ pub mod signature { .iter() .chain(other.outputs.iter()) .zip(self.inputs.iter().chain(self.outputs.iter())) - .all(|(decl_ty, ty)| check_type::(decl_ty, ty, &mut constants)) + .all(|(decl_ty, ty)| check_type::(decl_ty, ty, &mut constants)) } } - impl<'ast> DeclarationSignature<'ast> { + impl<'ast, T: Clone + PartialEq + fmt::Debug> DeclarationSignature<'ast, T> { pub fn specialize( &self, values: Vec>, @@ -1144,7 +1153,7 @@ pub mod signature { } } - pub fn get_output_types( + pub fn get_output_types( &self, generics: Vec>>, inputs: Vec>, @@ -1223,7 +1232,7 @@ pub mod signature { } } - impl<'ast> From for DeclarationSignature<'ast> { + impl<'ast, T> From for DeclarationSignature<'ast, T> { fn from(s: ConcreteSignature) -> Self { try_from_g_signature(s).unwrap() } diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 4620dd7cf..df4e476ec 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -133,13 +133,13 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct UMetadata { pub bitwidth: Option, pub should_reduce: Option, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct UExpression<'ast, T> { pub bitwidth: UBitwidth, pub metadata: Option, @@ -173,7 +173,7 @@ impl<'ast, T> PartialEq for UExpression<'ast, T> { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] pub enum UExpressionInner<'ast, T> { Block(BlockExpression<'ast, T, UExpression<'ast, T>>), Identifier(Identifier<'ast>), diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index 2d19a95ef..0e9d69872 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -5,13 +5,13 @@ use crate::typed_absy::UExpression; use crate::typed_absy::{TryFrom, TryInto}; use std::fmt; -#[derive(Clone, PartialEq, Hash, Eq)] +#[derive(Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct GVariable<'ast, S> { pub id: Identifier<'ast>, pub _type: GType, } -pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>; +pub type DeclarationVariable<'ast, T> = GVariable<'ast, DeclarationConstant<'ast, T>>; pub type ConcreteVariable<'ast> = GVariable<'ast, usize>; pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; From 0d3fc27d5e2fc87f31c79980ac3897264a8d2c78 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 27 Aug 2021 16:03:38 +0200 Subject: [PATCH 27/78] clippy --- zokrates_core/src/static_analysis/reducer/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index cb2ba50e1..81ebca4cb 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -170,7 +170,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { }), )) } else { - return Err(Error::ConstantReduction(key.id.to_string(), key.module)); + Err(Error::ConstantReduction(key.id.to_string(), key.module)); } } _ => unreachable!("all constants should be local"), From 28b2a28ed3721d165a65ee086e9f6ac6873bc0f8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 27 Aug 2021 16:05:59 +0200 Subject: [PATCH 28/78] remove U32ToField --- zokrates_core/src/embed.rs | 5 ----- zokrates_core/src/static_analysis/propagation.rs | 1 - 2 files changed, 6 deletions(-) diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 3ee07660d..2fa5e3f55 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -29,7 +29,6 @@ cfg_if::cfg_if! { #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] pub enum FlatEmbed { BitArrayLe, - U32ToField, Unpack, U8ToBits, U16ToBits, @@ -72,9 +71,6 @@ impl FlatEmbed { )), ]) .outputs(vec![DeclarationType::Boolean]), - FlatEmbed::U32ToField => DeclarationSignature::new() - .inputs(vec![DeclarationType::uint(32)]) - .outputs(vec![DeclarationType::FieldElement]), FlatEmbed::Unpack => DeclarationSignature::new() .generics(vec![Some(DeclarationConstant::Generic( GenericIdentifier { @@ -198,7 +194,6 @@ impl FlatEmbed { pub fn id(&self) -> &'static str { match self { FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", - FlatEmbed::U32ToField => "_U32_TO_FIELD", FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", FlatEmbed::U16ToBits => "_U16_TO_BITS", diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index b90fc99d6..6cab5c7c7 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -384,7 +384,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { match arguments.iter().all(|a| a.is_constant()) { true => { let r: Option> = match embed { - FlatEmbed::U32ToField => None, // todo FlatEmbed::BitArrayLe => None, // todo FlatEmbed::U64FromBits => Some(process_u_from_bits( assignees.clone(), From fce2c9f32cbd6ce11bd2c1101baa52bbea52051e Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 30 Aug 2021 15:06:21 +0200 Subject: [PATCH 29/78] flatten prog struct --- zokrates_cli/src/ops/compute_witness.rs | 7 +- zokrates_core/src/embed.rs | 8 +- zokrates_core/src/flat_absy/flat_parameter.rs | 3 +- zokrates_core/src/ir/folder.rs | 45 ++-- zokrates_core/src/ir/from_flat.rs | 67 +++--- zokrates_core/src/ir/interpreter.rs | 12 +- zokrates_core/src/ir/mod.rs | 84 +++----- zokrates_core/src/ir/serialize.rs | 20 +- zokrates_core/src/ir/smtlib2.rs | 10 +- zokrates_core/src/ir/visitor.rs | 32 ++- zokrates_core/src/optimizer/directive.rs | 21 +- zokrates_core/src/optimizer/duplicate.rs | 106 +++++----- zokrates_core/src/optimizer/redefinition.rs | 166 +++++++-------- zokrates_core/src/proof_system/ark/gm17.rs | 36 ++-- zokrates_core/src/proof_system/ark/marlin.rs | 56 +++-- zokrates_core/src/proof_system/ark/mod.rs | 49 ++--- .../src/proof_system/bellman/groth16.rs | 20 +- zokrates_core/src/proof_system/bellman/mod.rs | 192 +++++++----------- .../src/static_analysis/unconstrained_vars.rs | 63 ++---- zokrates_core/tests/wasm.rs | 20 +- 20 files changed, 411 insertions(+), 606 deletions(-) diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index f0079cc3f..675ef77d2 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -105,11 +105,8 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu abi.signature() } false => ConcreteSignature::new() - .inputs(vec![ - ConcreteType::FieldElement; - ir_prog.main.arguments.len() - ]) - .outputs(vec![ConcreteType::FieldElement; ir_prog.main.returns.len()]), + .inputs(vec![ConcreteType::FieldElement; ir_prog.arguments.len()]) + .outputs(vec![ConcreteType::FieldElement; ir_prog.returns.len()]), }; use zokrates_abi::Inputs; diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 3ee07660d..20054e6db 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -670,11 +670,9 @@ mod tests { ) ); - let f = crate::ir::Function::from(compiled); - let prog = crate::ir::Prog { - main: f, - private: vec![true; 768], - }; + let flat_prog = crate::flat_absy::FlatProg { main: compiled }; + + let prog = crate::ir::Prog::from(flat_prog); let input: Vec<_> = (0..512) .map(|_| 0) diff --git a/zokrates_core/src/flat_absy/flat_parameter.rs b/zokrates_core/src/flat_absy/flat_parameter.rs index 0b9481cb2..c54ec5e7d 100644 --- a/zokrates_core/src/flat_absy/flat_parameter.rs +++ b/zokrates_core/src/flat_absy/flat_parameter.rs @@ -1,8 +1,9 @@ use crate::flat_absy::flat_variable::FlatVariable; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; -#[derive(Clone, PartialEq)] +#[derive(Serialize, Deserialize, Hash, Eq, PartialEq, Clone, Copy)] pub struct FlatParameter { pub id: FlatVariable, pub private: bool, diff --git a/zokrates_core/src/ir/folder.rs b/zokrates_core/src/ir/folder.rs index 4a11a04f0..6abb7b86a 100644 --- a/zokrates_core/src/ir/folder.rs +++ b/zokrates_core/src/ir/folder.rs @@ -9,11 +9,7 @@ pub trait Folder: Sized { fold_module(self, p) } - fn fold_function(&mut self, f: Function) -> Function { - fold_function(self, f) - } - - fn fold_argument(&mut self, p: FlatVariable) -> FlatVariable { + fn fold_argument(&mut self, p: FlatParameter) -> FlatParameter { fold_argument(self, p) } @@ -40,8 +36,17 @@ pub trait Folder: Sized { pub fn fold_module>(f: &mut F, p: Prog) -> Prog { Prog { - main: f.fold_function(p.main), - ..p + arguments: p + .arguments + .into_iter() + .map(|a| f.fold_argument(a)) + .collect(), + statements: p + .statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + returns: p.returns.into_iter().map(|v| f.fold_variable(v)).collect(), } } @@ -86,31 +91,13 @@ pub fn fold_directive>(f: &mut F, ds: Directive) -> Di } } -pub fn fold_function>(f: &mut F, fun: Function) -> Function { - Function { - arguments: fun - .arguments - .into_iter() - .map(|a| f.fold_argument(a)) - .collect(), - statements: fun - .statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - returns: fun - .returns - .into_iter() - .map(|v| f.fold_variable(v)) - .collect(), - ..fun +pub fn fold_argument>(f: &mut F, a: FlatParameter) -> FlatParameter { + FlatParameter { + id: f.fold_variable(a.id), + private: a.private, } } -pub fn fold_argument>(f: &mut F, a: FlatVariable) -> FlatVariable { - f.fold_variable(a) -} - pub fn fold_variable>(_f: &mut F, v: FlatVariable) -> FlatVariable { v } diff --git a/zokrates_core/src/ir/from_flat.rs b/zokrates_core/src/ir/from_flat.rs index c3a5dfbe2..9d8aff4eb 100644 --- a/zokrates_core/src/ir/from_flat.rs +++ b/zokrates_core/src/ir/from_flat.rs @@ -1,12 +1,28 @@ -use crate::flat_absy::{ - FlatDirective, FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable, -}; -use crate::ir::{Directive, Function, LinComb, Prog, QuadComb, Statement}; +use crate::flat_absy::{FlatDirective, FlatExpression, FlatProg, FlatStatement, FlatVariable}; +use crate::ir::{Directive, LinComb, Prog, QuadComb, Statement}; use zokrates_field::Field; -impl From> for Function { - fn from(flat_function: FlatFunction) -> Function { - let return_expressions: Vec> = flat_function +impl QuadComb { + fn from_flat_expression>>(flat_expression: U) -> QuadComb { + let flat_expression = flat_expression.into(); + match flat_expression.is_linear() { + true => LinComb::from(flat_expression).into(), + false => match flat_expression { + FlatExpression::Mult(box e1, box e2) => { + QuadComb::from_linear_combinations(e1.into(), e2.into()) + } + e => unimplemented!("{}", e), + }, + } + } +} + +impl From> for Prog { + fn from(flat_prog: FlatProg) -> Prog { + // get the main function + let main = flat_prog.main; + + let return_expressions: Vec> = main .statements .iter() .filter_map(|s| match s { @@ -15,15 +31,15 @@ impl From> for Function { }) .next() .unwrap(); - Function { - id: String::from("main"), - arguments: flat_function.arguments.into_iter().map(|p| p.id).collect(), + + Prog { + arguments: main.arguments, returns: return_expressions .iter() .enumerate() .map(|(index, _)| FlatVariable::public(index)) .collect(), - statements: flat_function + statements: main .statements .into_iter() .filter_map(|s| match s { @@ -47,35 +63,6 @@ impl From> for Function { } } -impl QuadComb { - fn from_flat_expression>>(flat_expression: U) -> QuadComb { - let flat_expression = flat_expression.into(); - match flat_expression.is_linear() { - true => LinComb::from(flat_expression).into(), - false => match flat_expression { - FlatExpression::Mult(box e1, box e2) => { - QuadComb::from_linear_combinations(e1.into(), e2.into()) - } - e => unimplemented!("{}", e), - }, - } - } -} - -impl From> for Prog { - fn from(flat_prog: FlatProg) -> Prog { - // get the main function - let main = flat_prog.main; - - // get the interface of the program, i.e. which inputs are private and public - let private = main.arguments.iter().map(|p| p.private).collect(); - - let main = main.into(); - - Prog { main, private } - } -} - impl From> for LinComb { fn from(flat_expression: FlatExpression) -> LinComb { match flat_expression { diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 580e5ee89..8b071998c 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -34,15 +34,15 @@ impl Interpreter { impl Interpreter { pub fn execute(&self, program: &Prog, inputs: &[T]) -> ExecutionResult { - let main = &program.main; self.check_inputs(&program, &inputs)?; let mut witness = BTreeMap::new(); witness.insert(FlatVariable::one(), T::one()); - for (arg, value) in main.arguments.iter().zip(inputs.iter()) { - witness.insert(*arg, value.clone()); + + for (arg, value) in program.arguments.iter().zip(inputs.iter()) { + witness.insert(arg.id, value.clone()); } - for statement in main.statements.iter() { + for statement in program.statements.iter() { match statement { Statement::Constraint(quad, lin, message) => match lin.is_assignee(&witness) { true => { @@ -120,11 +120,11 @@ impl Interpreter { } fn check_inputs(&self, program: &Prog, inputs: &[U]) -> Result<(), Error> { - if program.main.arguments.len() == inputs.len() { + if program.arguments.len() == inputs.len() { Ok(()) } else { Err(Error::WrongInputCount { - expected: program.main.arguments.len(), + expected: program.arguments.len(), received: inputs.len(), }) } diff --git a/zokrates_core/src/ir/mod.rs b/zokrates_core/src/ir/mod.rs index b39a8407d..623308a26 100644 --- a/zokrates_core/src/ir/mod.rs +++ b/zokrates_core/src/ir/mod.rs @@ -74,20 +74,40 @@ impl fmt::Display for Statement { } } -#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] -pub struct Function { - pub id: String, +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq, Default)] +pub struct Prog { pub statements: Vec>, - pub arguments: Vec, + pub arguments: Vec, pub returns: Vec, } -impl fmt::Display for Function { +impl Prog { + pub fn constraint_count(&self) -> usize { + self.statements + .iter() + .filter(|s| matches!(s, Statement::Constraint(..))) + .count() + } + + pub fn arguments_count(&self) -> usize { + self.arguments.len() + } + + pub fn public_inputs(&self, witness: &Witness) -> Vec { + self.arguments + .iter() + .filter(|p| !p.private) + .map(|p| witness.0.get(&p.id).unwrap().clone()) + .chain(witness.return_values()) + .collect() + } +} + +impl fmt::Display for Prog { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "def {}({}) -> ({}):\n{}\n\t return {}", - self.id, + "def main({}) -> ({}):\n{}\n\treturn {}", self.arguments .iter() .map(|v| format!("{}", v)) @@ -108,56 +128,6 @@ impl fmt::Display for Function { } } -#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)] -pub struct Prog { - pub main: Function, - pub private: Vec, -} - -impl Prog { - pub fn constraint_count(&self) -> usize { - self.main - .statements - .iter() - .filter(|s| matches!(s, Statement::Constraint(..))) - .count() - } - - pub fn arguments_count(&self) -> usize { - self.private.len() - } - - pub fn parameters(&self) -> Vec { - self.main - .arguments - .iter() - .zip(self.private.iter()) - .map(|(id, private)| FlatParameter { - private: *private, - id: *id, - }) - .collect() - } - - pub fn public_inputs(&self, witness: &Witness) -> Vec { - self.main - .arguments - .clone() - .iter() - .zip(self.private.iter()) - .filter(|(_, p)| !**p) - .map(|(v, _)| witness.0.get(v).unwrap().clone()) - .chain(witness.return_values()) - .collect() - } -} - -impl fmt::Display for Prog { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.main) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/zokrates_core/src/ir/serialize.rs b/zokrates_core/src/ir/serialize.rs index d97220a4b..0f247f5e4 100644 --- a/zokrates_core/src/ir/serialize.rs +++ b/zokrates_core/src/ir/serialize.rs @@ -76,15 +76,7 @@ mod tests { #[test] fn ser_deser_v1() { - let p: ir::Prog = ir::Prog { - main: ir::Function { - arguments: vec![], - id: "something".to_string(), - returns: vec![], - statements: vec![], - }, - private: vec![], - }; + let p: ir::Prog = ir::Prog::default(); let mut buffer = Cursor::new(vec![]); p.serialize(&mut buffer); @@ -97,15 +89,7 @@ mod tests { assert_eq!(ProgEnum::Bn128Program(p), deserialized_p); - let p: ir::Prog = ir::Prog { - main: ir::Function { - arguments: vec![], - id: "something".to_string(), - returns: vec![], - statements: vec![], - }, - private: vec![], - }; + let p: ir::Prog = ir::Prog::default(); let mut buffer = Cursor::new(vec![]); p.serialize(&mut buffer); diff --git a/zokrates_core/src/ir/smtlib2.rs b/zokrates_core/src/ir/smtlib2.rs index 99b3be28f..560a1b909 100644 --- a/zokrates_core/src/ir/smtlib2.rs +++ b/zokrates_core/src/ir/smtlib2.rs @@ -12,12 +12,6 @@ pub trait SMTLib2 { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result; } -impl SMTLib2 for Prog { - fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.main.to_smtlib2(f) - } -} - pub struct SMTLib2Display<'a, T>(pub &'a Prog); impl fmt::Display for SMTLib2Display<'_, T> { @@ -36,12 +30,12 @@ impl Visitor for FlatVariableCollector { } } -impl SMTLib2 for Function { +impl SMTLib2 for Prog { fn to_smtlib2(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut collector = FlatVariableCollector { variables: BTreeSet::::new(), }; - collector.visit_function(self); + collector.visit_module(self); collector.variables.insert(FlatVariable::one()); writeln!(f, "; Auto generated by ZoKrates")?; diff --git a/zokrates_core/src/ir/visitor.rs b/zokrates_core/src/ir/visitor.rs index 4021ffc74..2a9cc028e 100644 --- a/zokrates_core/src/ir/visitor.rs +++ b/zokrates_core/src/ir/visitor.rs @@ -9,11 +9,7 @@ pub trait Visitor: Sized { visit_module(self, p) } - fn visit_function(&mut self, f: &Function) { - visit_function(self, f) - } - - fn visit_argument(&mut self, p: &FlatVariable) { + fn visit_argument(&mut self, p: &FlatParameter) { visit_argument(self, p) } @@ -47,7 +43,15 @@ pub trait Visitor: Sized { } pub fn visit_module>(f: &mut F, p: &Prog) { - f.visit_function(&p.main) + for expr in p.arguments.iter() { + f.visit_argument(expr); + } + for expr in p.statements.iter() { + f.visit_statement(expr); + } + for expr in p.returns.iter() { + f.visit_variable(expr); + } } pub fn visit_statement>(f: &mut F, s: &Statement) { @@ -84,20 +88,8 @@ pub fn visit_directive>(f: &mut F, ds: &Directive) { } } -pub fn visit_function>(f: &mut F, fun: &Function) { - for expr in fun.arguments.iter() { - f.visit_argument(expr); - } - for expr in fun.statements.iter() { - f.visit_statement(expr); - } - for expr in fun.returns.iter() { - f.visit_variable(expr); - } -} - -pub fn visit_argument>(f: &mut F, a: &FlatVariable) { - f.visit_variable(a) +pub fn visit_argument>(f: &mut F, a: &FlatParameter) { + f.visit_variable(&a.id) } pub fn visit_variable>(_f: &mut F, _v: &FlatVariable) {} diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index d84a72b35..2b4e8a4b6 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -16,6 +16,7 @@ use crate::optimizer::canonicalizer::Canonicalizer; use crate::solvers::Solver; use std::collections::hash_map::{Entry, HashMap}; use zokrates_field::Field; + #[derive(Debug)] pub struct DirectiveOptimizer { calls: HashMap<(Solver, Vec>), Vec>, @@ -37,21 +38,25 @@ impl DirectiveOptimizer { } impl Folder for DirectiveOptimizer { - fn fold_function(&mut self, f: Function) -> Function { - // in order to correcty identify duplicates, we need to first canonicalize the statements + fn fold_module(&mut self, p: Prog) -> Prog { + // in order to correctly identify duplicates, we need to first canonicalize the statements let mut canonicalizer = Canonicalizer; - let f = Function { - statements: f + let p = Prog { + statements: p .statements .into_iter() .flat_map(|s| canonicalizer.fold_statement(s)) .collect(), - ..f + ..p }; - fold_function(self, f) + fold_module(self, p) + } + + fn fold_variable(&mut self, v: FlatVariable) -> FlatVariable { + *self.substitution.get(&v).unwrap_or(&v) } fn fold_statement(&mut self, s: Statement) -> Vec> { @@ -74,10 +79,6 @@ impl Folder for DirectiveOptimizer { s => fold_statement(self, s), } } - - fn fold_variable(&mut self, v: FlatVariable) -> FlatVariable { - *self.substitution.get(&v).unwrap_or(&v) - } } #[cfg(test)] diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index 73c060787..493fe4e8f 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -34,20 +34,20 @@ impl DuplicateOptimizer { } impl Folder for DuplicateOptimizer { - fn fold_function(&mut self, f: Function) -> Function { - // in order to correcty identify duplicates, we need to first canonicalize the statements + fn fold_module(&mut self, p: Prog) -> Prog { + // in order to correctly identify duplicates, we need to first canonicalize the statements let mut canonicalizer = Canonicalizer; - let f = Function { - statements: f + let p = Prog { + statements: p .statements .into_iter() .flat_map(|s| canonicalizer.fold_statement(s)) .collect(), - ..f + ..p }; - fold_function(self, f) + fold_module(self, p) } fn fold_statement(&mut self, s: Statement) -> Vec> { @@ -71,28 +71,24 @@ mod tests { #[test] fn identity() { let p: Prog = Prog { - private: vec![], - main: Function { - id: "main".to_string(), - statements: vec![ - Statement::constraint( - QuadComb::from_linear_combinations( - LinComb::summand(3, FlatVariable::new(3)), - LinComb::summand(3, FlatVariable::new(3)), - ), - LinComb::one(), + statements: vec![ + Statement::constraint( + QuadComb::from_linear_combinations( + LinComb::summand(3, FlatVariable::new(3)), + LinComb::summand(3, FlatVariable::new(3)), ), - Statement::constraint( - QuadComb::from_linear_combinations( - LinComb::summand(3, FlatVariable::new(42)), - LinComb::summand(3, FlatVariable::new(3)), - ), - LinComb::zero(), + LinComb::one(), + ), + Statement::constraint( + QuadComb::from_linear_combinations( + LinComb::summand(3, FlatVariable::new(42)), + LinComb::summand(3, FlatVariable::new(3)), ), - ], - returns: vec![], - arguments: vec![], - }, + LinComb::zero(), + ), + ], + returns: vec![], + arguments: vec![], }; let expected = p.clone(); @@ -111,44 +107,36 @@ mod tests { ); let p: Prog = Prog { - private: vec![], - main: Function { - id: "main".to_string(), - statements: vec![ - constraint.clone(), - constraint.clone(), - Statement::constraint( - QuadComb::from_linear_combinations( - LinComb::summand(3, FlatVariable::new(42)), - LinComb::summand(3, FlatVariable::new(3)), - ), - LinComb::zero(), + statements: vec![ + constraint.clone(), + constraint.clone(), + Statement::constraint( + QuadComb::from_linear_combinations( + LinComb::summand(3, FlatVariable::new(42)), + LinComb::summand(3, FlatVariable::new(3)), ), - constraint.clone(), - constraint.clone(), - ], - returns: vec![], - arguments: vec![], - }, + LinComb::zero(), + ), + constraint.clone(), + constraint.clone(), + ], + returns: vec![], + arguments: vec![], }; let expected = Prog { - private: vec![], - main: Function { - id: "main".to_string(), - statements: vec![ - constraint, - Statement::constraint( - QuadComb::from_linear_combinations( - LinComb::summand(3, FlatVariable::new(42)), - LinComb::summand(3, FlatVariable::new(3)), - ), - LinComb::zero(), + statements: vec![ + constraint, + Statement::constraint( + QuadComb::from_linear_combinations( + LinComb::summand(3, FlatVariable::new(42)), + LinComb::summand(3, FlatVariable::new(3)), ), - ], - returns: vec![], - arguments: vec![], - }, + LinComb::zero(), + ), + ], + returns: vec![], + arguments: vec![], }; assert_eq!(DuplicateOptimizer::optimize(p), expected); diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index 4222281c0..28bab99e8 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -37,7 +37,8 @@ // - otherwise return `c_0` use crate::flat_absy::flat_variable::FlatVariable; -use crate::ir::folder::{fold_function, Folder}; +use crate::flat_absy::FlatParameter; +use crate::ir::folder::{fold_module, Folder}; use crate::ir::LinComb; use crate::ir::*; use std::collections::{HashMap, HashSet}; @@ -65,6 +66,22 @@ impl RedefinitionOptimizer { } impl Folder for RedefinitionOptimizer { + fn fold_module(&mut self, p: Prog) -> Prog { + // to prevent the optimiser from replacing outputs, add them to the ignored set + self.ignore.extend(p.returns.iter().cloned()); + + // to prevent the optimiser from replacing ~one, add it to the ignored set + self.ignore.insert(FlatVariable::one()); + + fold_module(self, p) + } + + fn fold_argument(&mut self, a: FlatParameter) -> FlatParameter { + // to prevent the optimiser from replacing user input, add it to the ignored set + self.ignore.insert(a.id); + a + } + fn fold_statement(&mut self, s: Statement) -> Vec> { match s { Statement::Constraint(quad, lin, message) => { @@ -191,27 +208,12 @@ impl Folder for RedefinitionOptimizer { false => lc, } } - - fn fold_argument(&mut self, a: FlatVariable) -> FlatVariable { - // to prevent the optimiser from replacing user input, add it to the ignored set - self.ignore.insert(a); - a - } - - fn fold_function(&mut self, fun: Function) -> Function { - // to prevent the optimiser from replacing outputs, add them to the ignored set - self.ignore.extend(fun.returns.iter().cloned()); - - // to prevent the optimiser from replacing ~one, add it to the ignored set - self.ignore.insert(FlatVariable::one()); - - fold_function(self, fun) - } } #[cfg(test)] mod tests { use super::*; + use crate::flat_absy::FlatParameter; use zokrates_field::Bn128Field; #[test] @@ -221,26 +223,24 @@ mod tests { // z = y // return z - let x = FlatVariable::new(0); + let x = FlatParameter::public(FlatVariable::new(0)); let y = FlatVariable::new(1); let z = FlatVariable::new(2); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x], - statements: vec![Statement::definition(y, x), Statement::definition(z, y)], + statements: vec![Statement::definition(y, x.id), Statement::definition(z, y)], returns: vec![z], }; - let optimized: Function = Function { - id: "foo".to_string(), + let optimized: Prog = Prog { arguments: vec![x], - statements: vec![Statement::definition(z, x)], + statements: vec![Statement::definition(z, x.id)], returns: vec![z], }; let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } #[test] @@ -250,19 +250,18 @@ mod tests { // return one let one = FlatVariable::one(); - let x = FlatVariable::new(1); + let x = FlatParameter::public(FlatVariable::new(0)); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x], - statements: vec![Statement::definition(one, x)], - returns: vec![x], + statements: vec![Statement::definition(one, x.id)], + returns: vec![x.id], }; - let optimized = f.clone(); + let optimized = p.clone(); let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } #[test] @@ -279,30 +278,31 @@ mod tests { // x == x // will be eliminated as a tautology // return x - let x = FlatVariable::new(0); + let x = FlatParameter::public(FlatVariable::new(0)); let y = FlatVariable::new(1); let z = FlatVariable::new(2); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x], statements: vec![ - Statement::definition(y, x), + Statement::definition(y, x.id), Statement::definition(z, y), Statement::constraint(z, y), ], returns: vec![z], }; - let optimized: Function = Function { - id: "foo".to_string(), + let optimized: Prog = Prog { arguments: vec![x], - statements: vec![Statement::definition(z, x), Statement::constraint(z, x)], + statements: vec![ + Statement::definition(z, x.id), + Statement::constraint(z, x.id), + ], returns: vec![z], }; let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } #[test] @@ -319,17 +319,16 @@ mod tests { // def main(x): // return x, 1 - let x = FlatVariable::new(0); + let x = FlatParameter::public(FlatVariable::new(0)); let y = FlatVariable::new(1); let z = FlatVariable::new(2); let t = FlatVariable::new(3); let w = FlatVariable::new(4); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x], statements: vec![ - Statement::definition(y, x), + Statement::definition(y, x.id), Statement::definition(t, Bn128Field::from(1)), Statement::definition(z, y), Statement::definition(w, t), @@ -337,11 +336,10 @@ mod tests { returns: vec![z, w], }; - let optimized: Function = Function { - id: "foo".to_string(), + let optimized: Prog = Prog { arguments: vec![x], statements: vec![ - Statement::definition(z, x), + Statement::definition(z, x.id), Statement::definition(w, Bn128Field::from(1)), ], returns: vec![z, w], @@ -349,7 +347,7 @@ mod tests { let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } #[test] @@ -368,45 +366,49 @@ mod tests { // 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology // return 6*x + 6*y - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); + let x = FlatParameter::public(FlatVariable::new(0)); + let y = FlatParameter::public(FlatVariable::new(1)); let a = FlatVariable::new(2); let b = FlatVariable::new(3); let c = FlatVariable::new(4); let r = FlatVariable::new(5); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x, y], statements: vec![ - Statement::definition(a, LinComb::from(x) + LinComb::from(y)), - Statement::definition(b, LinComb::from(a) + LinComb::from(x) + LinComb::from(y)), - Statement::definition(c, LinComb::from(b) + LinComb::from(x) + LinComb::from(y)), + Statement::definition(a, LinComb::from(x.id) + LinComb::from(y.id)), + Statement::definition( + b, + LinComb::from(a) + LinComb::from(x.id) + LinComb::from(y.id), + ), + Statement::definition( + c, + LinComb::from(b) + LinComb::from(x.id) + LinComb::from(y.id), + ), Statement::constraint( LinComb::summand(2, c), - LinComb::summand(6, x) + LinComb::summand(6, y), + LinComb::summand(6, x.id) + LinComb::summand(6, y.id), ), Statement::definition(r, LinComb::from(a) + LinComb::from(b) + LinComb::from(c)), ], returns: vec![r], }; - let expected: Function = Function { - id: "foo".to_string(), + let expected: Prog = Prog { arguments: vec![x, y], statements: vec![ Statement::constraint( - LinComb::summand(6, x) + LinComb::summand(6, y), - LinComb::summand(6, x) + LinComb::summand(6, y), + LinComb::summand(6, x.id) + LinComb::summand(6, y.id), + LinComb::summand(6, x.id) + LinComb::summand(6, y.id), ), Statement::definition( r, - LinComb::summand(1, x) - + LinComb::summand(1, y) - + LinComb::summand(2, x) - + LinComb::summand(2, y) - + LinComb::summand(3, x) - + LinComb::summand(3, y), + LinComb::summand(1, x.id) + + LinComb::summand(1, y.id) + + LinComb::summand(2, x.id) + + LinComb::summand(2, y.id) + + LinComb::summand(3, x.id) + + LinComb::summand(3, y.id), ), ], returns: vec![r], @@ -414,7 +416,7 @@ mod tests { let mut optimizer = RedefinitionOptimizer::new(); - let optimized = optimizer.fold_function(f); + let optimized = optimizer.fold_module(p); assert_eq!(optimized, expected); } @@ -433,27 +435,26 @@ mod tests { // z = x // return - let x = FlatVariable::new(0); - let y = FlatVariable::new(1); + let x = FlatParameter::public(FlatVariable::new(0)); + let y = FlatParameter::public(FlatVariable::new(1)); let z = FlatVariable::new(2); - let f: Function = Function { - id: "main".to_string(), + let p: Prog = Prog { arguments: vec![x, y], statements: vec![ Statement::definition( z, - QuadComb::from_linear_combinations(LinComb::from(x), LinComb::from(y)), + QuadComb::from_linear_combinations(LinComb::from(x.id), LinComb::from(y.id)), ), - Statement::definition(z, LinComb::from(x)), + Statement::definition(z, LinComb::from(x.id)), ], returns: vec![], }; - let optimized = f.clone(); + let optimized = p.clone(); let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } #[test] @@ -467,21 +468,20 @@ mod tests { // unchanged - let x = FlatVariable::new(0); + let x = FlatParameter::public(FlatVariable::new(0)); - let f: Function = Function { - id: "foo".to_string(), + let p: Prog = Prog { arguments: vec![x], statements: vec![ - Statement::constraint(x, Bn128Field::from(1)), - Statement::constraint(x, Bn128Field::from(2)), + Statement::constraint(x.id, Bn128Field::from(1)), + Statement::constraint(x.id, Bn128Field::from(2)), ], - returns: vec![x], + returns: vec![x.id], }; - let optimized = f.clone(); + let optimized = p.clone(); let mut optimizer = RedefinitionOptimizer::new(); - assert_eq!(optimizer.fold_function(f), optimized); + assert_eq!(optimizer.fold_module(p), optimized); } } diff --git a/zokrates_core/src/proof_system/ark/gm17.rs b/zokrates_core/src/proof_system/ark/gm17.rs index b79a90a60..4c54b261f 100644 --- a/zokrates_core/src/proof_system/ark/gm17.rs +++ b/zokrates_core/src/proof_system/ark/gm17.rs @@ -250,8 +250,8 @@ pub mod serialization { #[cfg(test)] mod tests { - use crate::flat_absy::FlatVariable; - use crate::ir::{Function, Interpreter, Prog, Statement}; + use crate::flat_absy::{FlatParameter, FlatVariable}; + use crate::ir::{Interpreter, Prog, Statement}; use super::*; use zokrates_field::{Bls12_377Field, Bw6_761Field}; @@ -259,16 +259,12 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let keypair = >::setup(program.clone()); @@ -288,16 +284,12 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let keypair = >::setup(program.clone()); diff --git a/zokrates_core/src/proof_system/ark/marlin.rs b/zokrates_core/src/proof_system/ark/marlin.rs index 1384e4d94..f49eb87c7 100644 --- a/zokrates_core/src/proof_system/ark/marlin.rs +++ b/zokrates_core/src/proof_system/ark/marlin.rs @@ -188,8 +188,8 @@ impl Backend for Ark { #[cfg(test)] mod tests { - use crate::flat_absy::FlatVariable; - use crate::ir::{Function, Interpreter, Prog, QuadComb, Statement}; + use crate::flat_absy::{FlatParameter, FlatVariable}; + use crate::ir::{Interpreter, Prog, QuadComb, Statement}; use super::*; use crate::proof_system::scheme::Marlin; @@ -198,22 +198,18 @@ mod tests { #[test] fn verify_bls12_377_field() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![ - Statement::constraint( - QuadComb::from_linear_combinations( - FlatVariable::new(0).into(), - FlatVariable::new(0).into(), - ), - FlatVariable::new(1), + arguments: vec![FlatParameter::private(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![ + Statement::constraint( + QuadComb::from_linear_combinations( + FlatVariable::new(0).into(), + FlatVariable::new(0).into(), ), - Statement::constraint(FlatVariable::new(1), FlatVariable::public(0)), - ], - }, - private: vec![true], + FlatVariable::new(1), + ), + Statement::constraint(FlatVariable::new(1), FlatVariable::public(0)), + ], }; let srs = >::universal_setup(5); @@ -235,22 +231,18 @@ mod tests { #[test] fn verify_bw6_761_field() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![ - Statement::constraint( - QuadComb::from_linear_combinations( - FlatVariable::new(0).into(), - FlatVariable::new(0).into(), - ), - FlatVariable::new(1), + arguments: vec![FlatParameter::private(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![ + Statement::constraint( + QuadComb::from_linear_combinations( + FlatVariable::new(0).into(), + FlatVariable::new(0).into(), ), - Statement::constraint(FlatVariable::new(1), FlatVariable::public(0)), - ], - }, - private: vec![true], + FlatVariable::new(1), + ), + Statement::constraint(FlatVariable::new(1), FlatVariable::public(0)), + ], }; let srs = >::universal_setup(5); diff --git a/zokrates_core/src/proof_system/ark/mod.rs b/zokrates_core/src/proof_system/ark/mod.rs index 11b25ddbd..8527f3bd4 100644 --- a/zokrates_core/src/proof_system/ark/mod.rs +++ b/zokrates_core/src/proof_system/ark/mod.rs @@ -95,37 +95,28 @@ impl Prog { match cs { ConstraintSystemRef::CS(rc) => { let mut cs = rc.borrow_mut(); - symbols.extend( - self.main - .arguments - .iter() - .zip(self.private) - .enumerate() - .map(|(_, (var, private))| { - let wire = match private { - true => cs.new_witness_variable(|| { - Ok(witness - .0 - .remove(&var) - .ok_or(SynthesisError::AssignmentMissing)? - .into_ark()) - }), - false => cs.new_input_variable(|| { - Ok(witness - .0 - .remove(&var) - .ok_or(SynthesisError::AssignmentMissing)? - .into_ark()) - }), - } - .unwrap(); - (*var, wire) + symbols.extend(self.arguments.iter().enumerate().map(|(_, p)| { + let wire = match p.private { + true => cs.new_witness_variable(|| { + Ok(witness + .0 + .remove(&p.id) + .ok_or(SynthesisError::AssignmentMissing)? + .into_ark()) }), - ); - - let main = self.main; + false => cs.new_input_variable(|| { + Ok(witness + .0 + .remove(&p.id) + .ok_or(SynthesisError::AssignmentMissing)? + .into_ark()) + }), + } + .unwrap(); + (p.id, wire) + })); - for statement in main.statements { + for statement in self.statements { if let Statement::Constraint(quad, lin, _) = statement { let a = ark_combination( quad.left.clone().into_canonical(), diff --git a/zokrates_core/src/proof_system/bellman/groth16.rs b/zokrates_core/src/proof_system/bellman/groth16.rs index 24cb544b1..1c22dc5ca 100644 --- a/zokrates_core/src/proof_system/bellman/groth16.rs +++ b/zokrates_core/src/proof_system/bellman/groth16.rs @@ -136,22 +136,18 @@ mod tests { use zokrates_field::Bn128Field; use super::*; - use crate::flat_absy::FlatVariable; - use crate::ir::{Function, Interpreter, Prog, Statement}; + use crate::flat_absy::{FlatParameter, FlatVariable}; + use crate::ir::{Interpreter, Prog, Statement}; #[test] fn verify() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let keypair = >::setup(program.clone()); diff --git a/zokrates_core/src/proof_system/bellman/mod.rs b/zokrates_core/src/proof_system/bellman/mod.rs index f8036c75a..84a63bab2 100644 --- a/zokrates_core/src/proof_system/bellman/mod.rs +++ b/zokrates_core/src/proof_system/bellman/mod.rs @@ -94,43 +94,34 @@ impl Prog { assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none()); - symbols.extend( - self.main - .arguments - .iter() - .zip(self.private) - .enumerate() - .map(|(index, (var, private))| { - let wire = match private { - true => cs.alloc( - || format!("PRIVATE_INPUT_{}", index), - || { - Ok(witness - .0 - .remove(&var) - .ok_or(SynthesisError::AssignmentMissing)? - .into_bellman()) - }, - ), - false => cs.alloc_input( - || format!("PUBLIC_INPUT_{}", index), - || { - Ok(witness - .0 - .remove(&var) - .ok_or(SynthesisError::AssignmentMissing)? - .into_bellman()) - }, - ), - } - .unwrap(); - (*var, wire) - }), - ); - - let main = self.main; + symbols.extend(self.arguments.iter().enumerate().map(|(index, p)| { + let wire = match p.private { + true => cs.alloc( + || format!("PRIVATE_INPUT_{}", index), + || { + Ok(witness + .0 + .remove(&p.id) + .ok_or(SynthesisError::AssignmentMissing)? + .into_bellman()) + }, + ), + false => cs.alloc_input( + || format!("PUBLIC_INPUT_{}", index), + || { + Ok(witness + .0 + .remove(&p.id) + .ok_or(SynthesisError::AssignmentMissing)? + .into_bellman()) + }, + ), + } + .unwrap(); + (p.id, wire) + })); - for statement in main.statements { + for statement in self.statements { if let Statement::Constraint(quad, lin, _) = statement { let a = &bellman_combination( quad.left.into_canonical(), @@ -270,23 +261,16 @@ mod parse { mod tests { use super::*; use crate::ir::Interpreter; - use crate::ir::{Function, LinComb}; + use crate::ir::LinComb; use zokrates_field::Bn128Field; mod prove { use super::*; + use crate::flat_absy::FlatParameter; #[test] fn empty() { - let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![], - returns: vec![], - statements: vec![], - }, - private: vec![], - }; + let program: Prog = Prog::default(); let interpreter = Interpreter::default(); @@ -300,16 +284,12 @@ mod tests { #[test] fn identity() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![true], + arguments: vec![FlatParameter::private(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let interpreter = Interpreter::default(); @@ -327,16 +307,12 @@ mod tests { #[test] fn public_identity() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let interpreter = Interpreter::default(); @@ -354,16 +330,12 @@ mod tests { #[test] fn no_arguments() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::one(), - FlatVariable::public(0), - )], - }, - private: vec![], + arguments: vec![], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::one(), + FlatVariable::public(0), + )], }; let interpreter = Interpreter::default(); @@ -380,24 +352,21 @@ mod tests { // public variables must be ordered from 0 // private variables can be unordered let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(42), FlatVariable::new(51)], - returns: vec![FlatVariable::public(0), FlatVariable::public(1)], - statements: vec![ - Statement::constraint( - LinComb::from(FlatVariable::new(42)) - + LinComb::from(FlatVariable::new(51)), - FlatVariable::public(0), - ), - Statement::constraint( - LinComb::from(FlatVariable::one()) - + LinComb::from(FlatVariable::new(42)), - FlatVariable::public(1), - ), - ], - }, - private: vec![true, false], + arguments: vec![ + FlatParameter::private(FlatVariable::new(42)), + FlatParameter::public(FlatVariable::new(51)), + ], + returns: vec![FlatVariable::public(0), FlatVariable::public(1)], + statements: vec![ + Statement::constraint( + LinComb::from(FlatVariable::new(42)) + LinComb::from(FlatVariable::new(51)), + FlatVariable::public(0), + ), + Statement::constraint( + LinComb::from(FlatVariable::one()) + LinComb::from(FlatVariable::new(42)), + FlatVariable::public(1), + ), + ], }; let interpreter = Interpreter::default(); @@ -414,16 +383,12 @@ mod tests { #[test] fn one() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(42)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - LinComb::from(FlatVariable::new(42)) + LinComb::one(), - FlatVariable::public(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(42))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + LinComb::from(FlatVariable::new(42)) + LinComb::one(), + FlatVariable::public(0), + )], }; let interpreter = Interpreter::default(); @@ -441,16 +406,15 @@ mod tests { #[test] fn with_directives() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(42), FlatVariable::new(51)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - LinComb::from(FlatVariable::new(42)) + LinComb::from(FlatVariable::new(51)), - FlatVariable::public(0), - )], - }, - private: vec![true, false], + arguments: vec![ + FlatParameter::private(FlatVariable::new(42)), + FlatParameter::public(FlatVariable::new(51)), + ], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + LinComb::from(FlatVariable::new(42)) + LinComb::from(FlatVariable::new(51)), + FlatVariable::public(0), + )], }; let interpreter = Interpreter::default(); diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index 5eb0b80eb..c285a4e47 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -1,4 +1,4 @@ -use crate::flat_absy::FlatVariable; +use crate::flat_absy::{FlatParameter, FlatVariable}; use crate::ir::visitor::Visitor; use crate::ir::Directive; use crate::ir::Prog; @@ -6,7 +6,7 @@ use std::collections::HashSet; use std::fmt; use zokrates_field::Field; -#[derive(Debug)] +#[derive(Debug, Default)] pub struct UnconstrainedVariableDetector { pub(self) variables: HashSet, } @@ -26,19 +26,8 @@ impl fmt::Display for Error { } impl UnconstrainedVariableDetector { - fn new(p: &Prog) -> Self { - UnconstrainedVariableDetector { - variables: p - .parameters() - .iter() - .filter(|p| p.private) - .map(|p| p.id) - .collect(), - } - } - pub fn detect(p: Prog) -> Result, Error> { - let mut instance = Self::new(&p); + let mut instance = Self::default(); instance.visit_module(&p); if instance.variables.is_empty() { @@ -50,7 +39,11 @@ impl UnconstrainedVariableDetector { } impl Visitor for UnconstrainedVariableDetector { - fn visit_argument(&mut self, _: &FlatVariable) {} + fn visit_argument(&mut self, p: &FlatParameter) { + if p.private { + self.variables.insert(p.id); + } + } fn visit_variable(&mut self, v: &FlatVariable) { self.variables.remove(v); } @@ -63,7 +56,7 @@ impl Visitor for UnconstrainedVariableDetector { mod tests { use super::*; use crate::flat_absy::FlatVariable; - use crate::ir::{Function, LinComb, Prog, QuadComb, Statement}; + use crate::ir::{LinComb, Prog, QuadComb, Statement}; use crate::solvers::Solver; use zokrates_field::Bn128Field; @@ -73,13 +66,12 @@ mod tests { // (1 * ~one) * (42 * ~one) == 1 * ~out_0 // return ~out_0 - let _0 = FlatVariable::new(0); // unused var + let _0 = FlatParameter::private(FlatVariable::new(0)); // unused var let one = FlatVariable::one(); let out_0 = FlatVariable::public(0); - let main: Function = Function { - id: "main".to_string(), + let p: Prog = Prog { arguments: vec![_0], statements: vec![Statement::constraint( QuadComb::from_linear_combinations( @@ -91,11 +83,6 @@ mod tests { returns: vec![out_0], }; - let p: Prog = Prog { - private: vec![true], - main, - }; - let p = UnconstrainedVariableDetector::detect(p); assert!(p.is_err()); } @@ -106,21 +93,15 @@ mod tests { // (1 * ~one) * (1 * _0) == 1 * ~out_0 // return ~out_0 - let _0 = FlatVariable::new(0); + let _0 = FlatParameter::private(FlatVariable::new(0)); let out_0 = FlatVariable::public(0); - let main: Function = Function { - id: "main".to_string(), + let p: Prog = Prog { arguments: vec![_0], - statements: vec![Statement::definition(out_0, LinComb::from(_0))], + statements: vec![Statement::definition(out_0, LinComb::from(_0.id))], returns: vec![out_0], }; - let p: Prog = Prog { - private: vec![true], - main, - }; - let p = UnconstrainedVariableDetector::detect(p); assert!(p.is_ok()); } @@ -134,25 +115,24 @@ mod tests { // (1 * ~one) * (1 * ~one + (-1) * _1) == 1 * ~out_0 // return ~out_0 - let _0 = FlatVariable::new(0); + let _0 = FlatParameter::private(FlatVariable::new(0)); let _1 = FlatVariable::new(1); let _2 = FlatVariable::new(2); let out_0 = FlatVariable::public(0); let one = FlatVariable::one(); - let main: Function = Function { - id: "main".to_string(), + let p: Prog = Prog { arguments: vec![_0], statements: vec![ Statement::Directive(Directive { - inputs: vec![(LinComb::summand(-42, one) + LinComb::summand(1, _0)).into()], + inputs: vec![(LinComb::summand(-42, one) + LinComb::summand(1, _0.id)).into()], outputs: vec![_1, _2], solver: Solver::ConditionEq, }), Statement::constraint( QuadComb::from_linear_combinations( - LinComb::summand(-42, one) + LinComb::summand(1, _0), + LinComb::summand(-42, one) + LinComb::summand(1, _0.id), LinComb::summand(1, _2), ), LinComb::summand(1, _1), @@ -160,7 +140,7 @@ mod tests { Statement::constraint( QuadComb::from_linear_combinations( LinComb::summand(1, one) + LinComb::summand(-1, _1), - LinComb::summand(-42, one) + LinComb::summand(1, _0), + LinComb::summand(-42, one) + LinComb::summand(1, _0.id), ), LinComb::zero(), ), @@ -175,11 +155,6 @@ mod tests { returns: vec![out_0], }; - let p: Prog = Prog { - private: vec![true], - main, - }; - let p = UnconstrainedVariableDetector::detect(p); assert!(p.is_ok()); } diff --git a/zokrates_core/tests/wasm.rs b/zokrates_core/tests/wasm.rs index 7c7799aff..ba312d629 100644 --- a/zokrates_core/tests/wasm.rs +++ b/zokrates_core/tests/wasm.rs @@ -4,8 +4,8 @@ extern crate wasm_bindgen_test; extern crate zokrates_core; extern crate zokrates_field; use wasm_bindgen_test::*; -use zokrates_core::flat_absy::FlatVariable; -use zokrates_core::ir::{Function, Interpreter, Prog, Statement}; +use zokrates_core::flat_absy::{FlatParameter, FlatVariable}; +use zokrates_core::ir::{Interpreter, Prog, Statement}; use zokrates_core::proof_system::{Backend, NonUniversalBackend}; use zokrates_field::Bn128Field; @@ -15,16 +15,12 @@ use zokrates_core::proof_system::groth16::G16; #[wasm_bindgen_test] fn generate_proof() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::new(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::new(0), - )], - }, - private: vec![false], + arguments: vec![FlatParameter::public(FlatVariable::new(0))], + returns: vec![FlatVariable::new(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::new(0), + )], }; let interpreter = Interpreter::default(); From 34b631b64476f15120f7cc41699a435295d57c7a Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 30 Aug 2021 15:18:11 +0200 Subject: [PATCH 30/78] fix libsnark tests --- .../src/proof_system/libsnark/gm17.rs | 20 +++++++--------- .../src/proof_system/libsnark/mod.rs | 23 ++++++------------- .../src/proof_system/libsnark/pghr13.rs | 20 +++++++--------- 3 files changed, 23 insertions(+), 40 deletions(-) diff --git a/zokrates_core/src/proof_system/libsnark/gm17.rs b/zokrates_core/src/proof_system/libsnark/gm17.rs index a5d53dc09..af3c8b5ec 100644 --- a/zokrates_core/src/proof_system/libsnark/gm17.rs +++ b/zokrates_core/src/proof_system/libsnark/gm17.rs @@ -192,23 +192,19 @@ impl NonUniversalBackend for Libsnark { #[cfg(test)] mod tests { use super::*; - use crate::flat_absy::FlatVariable; - use crate::ir::{Function, Interpreter, Prog, Statement}; + use crate::flat_absy::{FlatParameter, FlatVariable}; + use crate::ir::{Interpreter, Prog, Statement}; use zokrates_field::Bn128Field; #[test] fn verify() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![true], + arguments: vec![FlatParameter::private(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let keypair = >::setup(program.clone()); diff --git a/zokrates_core/src/proof_system/libsnark/mod.rs b/zokrates_core/src/proof_system/libsnark/mod.rs index a2b52a9aa..480fe9a07 100644 --- a/zokrates_core/src/proof_system/libsnark/mod.rs +++ b/zokrates_core/src/proof_system/libsnark/mod.rs @@ -224,22 +224,13 @@ pub fn r1cs_program( let mut variables: HashMap = HashMap::new(); provide_variable_idx(&mut variables, &FlatVariable::one()); - for x in prog - .main - .arguments - .iter() - .enumerate() - .filter(|(index, _)| !prog.private[*index]) - { - provide_variable_idx(&mut variables, &x.1); + for x in prog.arguments.iter().filter(|p| !p.private) { + provide_variable_idx(&mut variables, &x.id); } - //Only the main function is relevant in this step, since all calls to other functions were resolved during flattening - let main = prog.main; - - //~out are added after main's arguments, since we want variables (columns) - //in the r1cs to be aligned like "public inputs | private inputs" - let main_return_count = main.returns.len(); + // ~out are added after main's arguments, since we want variables (columns) + // in the r1cs to be aligned like "public inputs | private inputs" + let main_return_count = prog.returns.len(); for i in 0..main_return_count { provide_variable_idx(&mut variables, &FlatVariable::public(i)); @@ -249,7 +240,7 @@ pub fn r1cs_program( let private_inputs_offset = variables.len(); // first pass through statements to populate `variables` - for (quad, lin) in main.statements.iter().filter_map(|s| match s { + for (quad, lin) in prog.statements.iter().filter_map(|s| match s { Statement::Constraint(quad, lin, _) => Some((quad, lin)), Statement::Directive(..) => None, }) { @@ -269,7 +260,7 @@ pub fn r1cs_program( let mut c = vec![]; // second pass to convert program to raw sparse vectors - for (quad, lin) in main.statements.into_iter().filter_map(|s| match s { + for (quad, lin) in prog.statements.into_iter().filter_map(|s| match s { Statement::Constraint(quad, lin, _) => Some((quad, lin)), Statement::Directive(..) => None, }) { diff --git a/zokrates_core/src/proof_system/libsnark/pghr13.rs b/zokrates_core/src/proof_system/libsnark/pghr13.rs index e78386c3f..075dbb27a 100644 --- a/zokrates_core/src/proof_system/libsnark/pghr13.rs +++ b/zokrates_core/src/proof_system/libsnark/pghr13.rs @@ -222,23 +222,19 @@ impl NonUniversalBackend for Libsnark { #[cfg(test)] mod tests { use super::*; - use crate::flat_absy::FlatVariable; - use crate::ir::{Function, Interpreter, Prog, Statement}; + use crate::flat_absy::{FlatParameter, FlatVariable}; + use crate::ir::{Interpreter, Prog, Statement}; use zokrates_field::Bn128Field; #[test] fn verify() { let program: Prog = Prog { - main: Function { - id: String::from("main"), - arguments: vec![FlatVariable::new(0)], - returns: vec![FlatVariable::public(0)], - statements: vec![Statement::constraint( - FlatVariable::new(0), - FlatVariable::public(0), - )], - }, - private: vec![true], + arguments: vec![FlatParameter::private(FlatVariable::new(0))], + returns: vec![FlatVariable::public(0)], + statements: vec![Statement::constraint( + FlatVariable::new(0), + FlatVariable::public(0), + )], }; let keypair = >::setup(program.clone()); From edc8c015c4aee230fdac18953af48fe0b860af3e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 30 Aug 2021 18:00:48 +0200 Subject: [PATCH 31/78] add breaking test --- .../examples/empty_spread_propagation.zok | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 zokrates_cli/examples/empty_spread_propagation.zok diff --git a/zokrates_cli/examples/empty_spread_propagation.zok b/zokrates_cli/examples/empty_spread_propagation.zok new file mode 100644 index 000000000..025565625 --- /dev/null +++ b/zokrates_cli/examples/empty_spread_propagation.zok @@ -0,0 +1,16 @@ +def func() -> bool: + for u32 i in 0..N do + endfor + + u64[N] y = [...[0; N-1], 1] // the rhs should *not* be reduced to [1] because the spread is not empty + u64 q = 0 + + for u32 i in 0..N do + q = y[i] + endfor + + return true + +def main(): + assert(func::<2>()) + return \ No newline at end of file From 8169e1839a671c02487ec2cee69e744e8cfdbe9e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 30 Aug 2021 18:01:35 +0200 Subject: [PATCH 32/78] fix rule by using PartialEq --- zokrates_core/src/static_analysis/propagation.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 6cab5c7c7..38c29a0cb 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1078,7 +1078,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }) // ignore spreads over empty arrays .filter_map(|e| match e { - TypedExpressionOrSpread::Spread(s) if s.array.size() == 0 => None, + TypedExpressionOrSpread::Spread(s) + if s.array.size() == UExpression::from(0u32) => + { + None + } e => Some(e), }) .collect(), From 469664b06974876e043349f7490740045374d502 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 30 Aug 2021 18:04:00 +0200 Subject: [PATCH 33/78] changelog --- changelogs/unreleased/987-schaeff | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/987-schaeff diff --git a/changelogs/unreleased/987-schaeff b/changelogs/unreleased/987-schaeff new file mode 100644 index 000000000..8fe0d3fc6 --- /dev/null +++ b/changelogs/unreleased/987-schaeff @@ -0,0 +1 @@ +Fix incorrect propagation of spreads \ No newline at end of file From 97a034cdb983225c0a7016fa82c2fb7a6e25cf65 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sat, 4 Sep 2021 23:36:14 +0200 Subject: [PATCH 34/78] implement imported constants, implement inference on generic structs --- .../examples/array_generic_inference.zok | 12 + .../examples/struct_generic_inference.zok | 16 + zokrates_core/src/compile.rs | 2 + zokrates_core/src/semantics.rs | 107 ++--- .../src/static_analysis/constant_inliner.rs | 387 ++---------------- .../src/static_analysis/reducer/mod.rs | 104 ++++- zokrates_core/src/typed_absy/abi.rs | 6 +- zokrates_core/src/typed_absy/identifier.rs | 6 + zokrates_core/src/typed_absy/result_folder.rs | 12 +- zokrates_core/src/typed_absy/types.rs | 11 +- 10 files changed, 243 insertions(+), 420 deletions(-) create mode 100644 zokrates_cli/examples/array_generic_inference.zok create mode 100644 zokrates_cli/examples/struct_generic_inference.zok diff --git a/zokrates_cli/examples/array_generic_inference.zok b/zokrates_cli/examples/array_generic_inference.zok new file mode 100644 index 000000000..c0cee5776 --- /dev/null +++ b/zokrates_cli/examples/array_generic_inference.zok @@ -0,0 +1,12 @@ +def myFct(u64[N] ignored) -> u64[N2]: + assert(2*N == N2) + + return [0; N2] + + +const u32 N = 3 +const u32 N2 = 2*N +def main(u64[N] arg) -> bool: + u64[N2] someVariable = myFct(arg) + + return true \ No newline at end of file diff --git a/zokrates_cli/examples/struct_generic_inference.zok b/zokrates_cli/examples/struct_generic_inference.zok new file mode 100644 index 000000000..07dbb53bb --- /dev/null +++ b/zokrates_cli/examples/struct_generic_inference.zok @@ -0,0 +1,16 @@ +struct SomeStruct { + u64[N] f +} + +def myFct(SomeStruct ignored) -> u32[N2]: + assert(2*N == N2) + + return [N3; N2] + + +const u32 N = 3 +const u32 N2 = 2*N +def main(SomeStruct arg) -> u32: + u32[N2] someVariable = myFct::<_, _, 42>(arg) + + return someVariable[0] \ No newline at end of file diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 172ad7473..de33ea26f 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -247,6 +247,8 @@ fn check_with_arena<'ast, T: Field, E: Into>( let typed_ast = Checker::check(compiled) .map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?; + log::trace!("\n{}", typed_ast); + let main_module = typed_ast.main.clone(); log::debug!("Run static analysis"); diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 9353ce7de..ef2237082 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -739,7 +739,10 @@ impl<'ast, T: Field> Checker<'ast, T> { let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into()); constants.push((id.clone(), TypedConstantSymbol::There(imported_id))); - self.insert_into_scope(Variable::with_id_and_type(declaration.id, crate::typed_absy::types::try_from_g_type(ty.clone()).unwrap())); + self.insert_into_scope(Variable::with_id_and_type(CoreIdentifier::Constant(CanonicalConstantIdentifier::new( + declaration.id, + module_id.into(), + )), crate::typed_absy::types::try_from_g_type(ty.clone()).unwrap())); state .constants @@ -853,7 +856,6 @@ impl<'ast, T: Field> Checker<'ast, T> { // we go through symbol declarations and check them for declaration in module.symbols { - println!("{:#?}", self.scope); self.check_symbol_declaration( declaration, module_id, @@ -970,6 +972,7 @@ impl<'ast, T: Field> Checker<'ast, T> { }); } }; + arguments_checked.push(DeclarationParameter { id: decl_v, private: arg.private, @@ -3328,7 +3331,7 @@ mod tests { fn field_in_range() { // The value of `P - 1` is a valid field literal let expr = Expression::FieldConstant(Bn128Field::max_value().to_biguint()).mock(); - assert!(Checker::::new() + assert!(Checker::::default() .check_expression(expr, &*MODULE_ID, &TypeMap::new()) .is_ok()); } @@ -3339,7 +3342,7 @@ mod tests { let value = Bn128Field::max_value().to_biguint().add(1u32); let expr = Expression::FieldConstant(value).mock(); - assert!(Checker::::new() + assert!(Checker::::default() .check_expression(expr, &*MODULE_ID, &TypeMap::new()) .is_err()); } @@ -3362,7 +3365,7 @@ mod tests { Expression::BooleanConstant(true).mock().into(), ]) .mock(); - assert!(Checker::::new() + assert!(Checker::::default() .check_expression(a, &*MODULE_ID, &types) .is_err()); @@ -3382,7 +3385,7 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::::new() + assert!(Checker::::default() .check_expression(a, &*MODULE_ID, &types) .is_ok()); @@ -3398,7 +3401,7 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::::new() + assert!(Checker::::default() .check_expression(a, &*MODULE_ID, &types) .is_err()); } @@ -3480,7 +3483,7 @@ mod tests { fn unifier() { // the unifier should only accept either a single type or many functions of different signatures for each symbol - let mut unifier = SymbolUnifier::default(); + let mut unifier = SymbolUnifier::::default(); // the `foo` type assert!(unifier.insert_type("foo")); @@ -3567,7 +3570,7 @@ mod tests { .collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&OwnedTypedModuleId::from("bar"), &mut state), @@ -3619,7 +3622,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -3695,7 +3698,7 @@ mod tests { let mut state = State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); } @@ -3733,7 +3736,7 @@ mod tests { let mut state = State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); } @@ -3785,7 +3788,7 @@ mod tests { let mut state = State::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -3823,7 +3826,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!(checker.check_module(&*MODULE_ID, &mut state), Ok(())); assert!(state .typed_modules @@ -3872,7 +3875,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -3914,7 +3917,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -3965,7 +3968,7 @@ mod tests { .collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -4013,7 +4016,7 @@ mod tests { .collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state).unwrap_err()[0] .inner @@ -4051,7 +4054,7 @@ mod tests { ) .mock()]); assert_eq!( - Checker::::new().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &*MODULE_ID, &state), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Undeclared symbol `K`".to_string() @@ -4086,7 +4089,7 @@ mod tests { ) .mock()]); assert_eq!( - Checker::::new().check_signature(signature, &*MODULE_ID, &state), + Checker::::default().check_signature(signature, &*MODULE_ID, &state), Ok(DeclarationSignature::new() .inputs(vec![DeclarationType::array(( DeclarationType::array(( @@ -4116,7 +4119,7 @@ mod tests { ) .mock(); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); checker.enter_scope(); assert_eq!( @@ -4233,7 +4236,7 @@ mod tests { let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_module(&*MODULE_ID, &mut state), Err(vec![Error { @@ -4350,7 +4353,7 @@ mod tests { let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); } @@ -4389,7 +4392,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_function(foo, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -4473,7 +4476,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_function(foo, &*MODULE_ID, &state), Ok(foo_checked) @@ -4527,7 +4530,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashMap::new(), 0, functions); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -4586,7 +4589,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashMap::new(), 0, functions); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -4632,7 +4635,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -4735,7 +4738,7 @@ mod tests { let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker.check_module(&*MODULE_ID, &mut state), Err(vec![Error { @@ -4821,7 +4824,7 @@ mod tests { let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker.check_module(&*MODULE_ID, &mut state), Err(vec![ @@ -4936,7 +4939,7 @@ mod tests { let mut state = State::::new(vec![((*MODULE_ID).clone(), module)].into_iter().collect()); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert!(checker.check_module(&*MODULE_ID, &mut state).is_ok()); } @@ -4975,7 +4978,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -5016,7 +5019,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Err(vec![ErrorInner { @@ -5124,7 +5127,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashMap::new(), 0, functions); assert_eq!( checker.check_function(bar, &*MODULE_ID, &state), Ok(bar_checked) @@ -5157,7 +5160,7 @@ mod tests { let modules = Modules::new(); let state = State::new(modules); - let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashMap::new(), 0, HashSet::new()); assert_eq!( checker .check_function(f, &*MODULE_ID, &state) @@ -5239,7 +5242,7 @@ mod tests { main: (*MODULE_ID).clone(), }; - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); assert_eq!( checker.check_program(program), Err(vec![Error { @@ -5259,7 +5262,7 @@ mod tests { // // should fail - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), @@ -5293,7 +5296,7 @@ mod tests { // // should fail - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), @@ -5340,7 +5343,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); checker.check_module(&*MODULE_ID, &mut state).unwrap(); @@ -5371,7 +5374,7 @@ mod tests { )); assert_eq!( - Checker::::new().check_struct_type_declaration( + Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, &*MODULE_ID, @@ -5415,7 +5418,7 @@ mod tests { )); assert_eq!( - Checker::::new().check_struct_type_declaration( + Checker::::default().check_struct_type_declaration( "Foo".into(), declaration, &*MODULE_ID, @@ -5449,7 +5452,7 @@ mod tests { .mock(); assert_eq!( - Checker::::new() + Checker::::default() .check_struct_type_declaration( "Foo".into(), declaration, @@ -5508,7 +5511,9 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - assert!(Checker::new().check_module(&*MODULE_ID, &mut state).is_ok()); + assert!(Checker::default() + .check_module(&*MODULE_ID, &mut state) + .is_ok()); assert_eq!( state .types @@ -5564,7 +5569,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - assert!(Checker::new() + assert!(Checker::default() .check_module(&*MODULE_ID, &mut state) .is_err()); } @@ -5597,7 +5602,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - assert!(Checker::new() + assert!(Checker::default() .check_module(&*MODULE_ID, &mut state) .is_err()); } @@ -5648,7 +5653,7 @@ mod tests { vec![((*MODULE_ID).clone(), module)].into_iter().collect(), ); - assert!(Checker::new() + assert!(Checker::default() .check_module(&*MODULE_ID, &mut state) .is_err()); } @@ -6139,7 +6144,9 @@ mod tests { modules: vec![("".into(), m)].into_iter().collect(), }; - let errors = Checker::::new().check_program(p).unwrap_err(); + let errors = Checker::::default() + .check_program(p) + .unwrap_err(); assert_eq!(errors.len(), 1); @@ -6158,7 +6165,7 @@ mod tests { // a = 42 let a = Assignee::Identifier("a").mock(); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); checker.enter_scope(); checker @@ -6190,7 +6197,7 @@ mod tests { ) .mock(); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); checker.enter_scope(); checker @@ -6240,7 +6247,7 @@ mod tests { ) .mock(); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::default(); checker.enter_scope(); checker diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index 2ba585f49..bb023ec49 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -149,124 +149,8 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { .collect::, _>>()?, functions: m .functions - .into_iter() - .map(|(key, fun)| { - Ok(( - self.fold_declaration_function_key(key)?, - self.fold_function_symbol(fun)?, - )) - }) - .collect::, _>>()? - .into_iter() - .collect(), }) } - - // fn fold_declaration_constant( - // &mut self, - // c: DeclarationConstant<'ast, T>, - // ) -> Result, Self::Error> { - // match c { - // // replace constants by their concrete value in declaration types - // DeclarationConstant::Constant(id) => { - // let id = CanonicalConstantIdentifier { - // module: self.fold_module_id(id.module)?, - // ..id - // }; - - // match self.get_constant(&id).unwrap() { - // TypedConstant { - // ty: DeclarationType::Uint(UBitwidth::B32), - // expression - // } => Ok(DeclarationConstant::Expression(expression)), - // c => Err(Error::Propagation(format!("Failed to reduce `{}` to a single u32 literal, try avoiding function calls in the definition of `{}` in {}", c, id.id, id.module.display()))) - // } - // } - // c => Ok(c), - // } - // } - - // fn fold_field_expression( - // &mut self, - // e: FieldElementExpression<'ast, T>, - // ) -> Result, Self::Error> { - // match e { - // FieldElementExpression::Identifier(ref id) => { - // match self.get_constant_for_identifier(id) { - // Some(c) => Ok(c.try_into().unwrap()), - // None => fold_field_expression(self, e), - // } - // } - // e => fold_field_expression(self, e), - // } - // } - - // fn fold_boolean_expression( - // &mut self, - // e: BooleanExpression<'ast, T>, - // ) -> Result, Self::Error> { - // match e { - // BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) { - // Some(c) => Ok(c.try_into().unwrap()), - // None => fold_boolean_expression(self, e), - // }, - // e => fold_boolean_expression(self, e), - // } - // } - - // fn fold_uint_expression_inner( - // &mut self, - // size: UBitwidth, - // e: UExpressionInner<'ast, T>, - // ) -> Result, Self::Error> { - // match e { - // UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) { - // Some(c) => { - // let e: UExpression<'ast, T> = c.try_into().unwrap(); - // Ok(e.into_inner()) - // } - // None => fold_uint_expression_inner(self, size, e), - // }, - // e => fold_uint_expression_inner(self, size, e), - // } - // } - - // fn fold_array_expression_inner( - // &mut self, - // ty: &ArrayType<'ast, T>, - // e: ArrayExpressionInner<'ast, T>, - // ) -> Result, Self::Error> { - // match e { - // ArrayExpressionInner::Identifier(ref id) => { - // match self.get_constant_for_identifier(id) { - // Some(c) => { - // let e: ArrayExpression<'ast, T> = c.try_into().unwrap(); - // Ok(e.into_inner()) - // } - // None => fold_array_expression_inner(self, ty, e), - // } - // } - // e => fold_array_expression_inner(self, ty, e), - // } - // } - - // fn fold_struct_expression_inner( - // &mut self, - // ty: &StructType<'ast, T>, - // e: StructExpressionInner<'ast, T>, - // ) -> Result, Self::Error> { - // match e { - // StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) - // { - // Some(c) => { - // let e: StructExpression<'ast, T> = c.try_into().unwrap(); - // Ok(e.into_inner()) - // } - // None => fold_struct_expression_inner(self, ty, e), - // }, - // e => fold_struct_expression_inner(self, ty, e), - // } - // } } #[cfg(test)] @@ -330,39 +214,9 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); - - let expected_main = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; + let expected_program = program.clone(); - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants, - }, - )] - .into_iter() - .collect(), - }; + let program = ConstantInliner::inline(program); assert_eq!(program, Ok(expected_program)) } @@ -372,13 +226,13 @@ mod tests { // const bool a = true // // def main() -> bool: - // return a + // return main.zok/a - let const_id = "a"; + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( - Identifier::from(const_id), + Identifier::from(const_id.clone()), ) .into()])], signature: DeclarationSignature::new() @@ -387,7 +241,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new(const_id, "main".into()), + const_id, TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Boolean(BooleanExpression::Value(true)), DeclarationType::Boolean, @@ -418,39 +272,9 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); + let expected_program = program.clone(); - let expected_main = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - BooleanExpression::Value(true).into() - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Boolean]), - }; - - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Boolean]), - ), - TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants, - }, - )] - .into_iter() - .collect(), - }; + let program = ConstantInliner::inline(program); assert_eq!(program, Ok(expected_program)) } @@ -462,11 +286,11 @@ mod tests { // def main() -> u32: // return a - let const_id = "a"; + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( - Identifier::from(const_id), + Identifier::from(const_id.clone()), ) .annotate(UBitwidth::B32) .into()])], @@ -476,7 +300,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new(const_id, "main".into()), + const_id, TypedConstantSymbol::Here(TypedConstant::new( UExpressionInner::Value(1u128) .annotate(UBitwidth::B32) @@ -509,39 +333,9 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); - - let expected_main = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128) - .annotate(UBitwidth::B32) - .into()])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), - }; + let expected_program = program.clone(); - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), - ), - TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants, - }, - )] - .into_iter() - .collect(), - }; + let program = ConstantInliner::inline(program); assert_eq!(program, Ok(expected_program)) } @@ -553,18 +347,18 @@ mod tests { // def main() -> field: // return a[0] + a[1] - let const_id = "a"; + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( FieldElementExpression::select( - ArrayExpressionInner::Identifier(Identifier::from(const_id)) + ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) .annotate(GType::FieldElement, 2usize), UExpressionInner::Value(0u128).annotate(UBitwidth::B32), ) .into(), FieldElementExpression::select( - ArrayExpressionInner::Identifier(Identifier::from(const_id)) + ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) .annotate(GType::FieldElement, 2usize), UExpressionInner::Value(1u128).annotate(UBitwidth::B32), ) @@ -577,7 +371,7 @@ mod tests { }; let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new(const_id, "main".into()), + const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::Array( ArrayExpressionInner::Value( @@ -620,63 +414,9 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); + let expected_program = program.clone(); - let expected_main = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( - FieldElementExpression::select( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ] - .into(), - ) - .annotate(GType::FieldElement, 2usize), - UExpressionInner::Value(0u128).annotate(UBitwidth::B32), - ) - .into(), - FieldElementExpression::select( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ] - .into(), - ) - .annotate(GType::FieldElement, 2usize), - UExpressionInner::Value(1u128).annotate(UBitwidth::B32), - ) - .into(), - ) - .into()])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; - - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants, - }, - )] - .into_iter() - .collect(), - }; + let program = ConstantInliner::inline(program); assert_eq!(program, Ok(expected_program)) } @@ -689,13 +429,13 @@ mod tests { // def main() -> field: // return b - let const_a_id = "a"; - let const_b_id = "b"; + let const_a_id = CanonicalConstantIdentifier::new("a", "main".into()); + let const_b_id = CanonicalConstantIdentifier::new("a", "main".into()); let main: TypedFunction = TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(), + FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(), ])], signature: DeclarationSignature::new() .inputs(vec![]) @@ -719,7 +459,7 @@ mod tests { .collect(), constants: vec![ ( - CanonicalConstantIdentifier::new(const_a_id, "main".into()), + const_a_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(1), @@ -728,11 +468,11 @@ mod tests { )), ), ( - CanonicalConstantIdentifier::new(const_b_id, "main".into()), + const_b_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Add( box FieldElementExpression::Identifier(Identifier::from( - const_a_id, + const_a_id.clone(), )), box FieldElementExpression::Number(Bn128Field::from(1)), )), @@ -748,60 +488,9 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); + let expected_program = program.clone(); - let expected_main = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; - - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants: vec![ - ( - CanonicalConstantIdentifier::new(const_a_id, "main".into()), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(1), - )), - DeclarationType::FieldElement, - )), - ), - ( - CanonicalConstantIdentifier::new(const_b_id, "main".into()), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(2), - )), - DeclarationType::FieldElement, - )), - ), - ] - .into_iter() - .collect(), - }, - )] - .into_iter() - .collect(), - }; + let program = ConstantInliner::inline(program); assert_eq!(program, Ok(expected_program)) } @@ -824,10 +513,10 @@ mod tests { // def main() -> field: // return FOO - let foo_const_id = "FOO"; + let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into()); let foo_module = TypedModule { functions: vec![( - DeclarationFunctionKey::with_location("main", "main") + DeclarationFunctionKey::with_location("foo", "main") .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], @@ -838,7 +527,7 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new(foo_const_id, "foo".into()), + foo_const_id.clone(), TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(42), @@ -850,6 +539,7 @@ mod tests { .collect(), }; + let main_const_id = CanonicalConstantIdentifier::new("FOO", "main".into()); let main_module = TypedModule { functions: vec![( DeclarationFunctionKey::with_location("main", "main").signature( @@ -860,7 +550,8 @@ mod tests { TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(), + FieldElementExpression::Identifier(Identifier::from(main_const_id.clone())) + .into(), ])], signature: DeclarationSignature::new() .inputs(vec![]) @@ -870,11 +561,8 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new(foo_const_id, "main".into()), - TypedConstantSymbol::There(CanonicalConstantIdentifier::new( - foo_const_id, - "foo".into(), - )), + main_const_id.clone(), + TypedConstantSymbol::There(foo_const_id), )] .into_iter() .collect(), @@ -901,7 +589,8 @@ mod tests { TypedFunctionSymbol::Here(TypedFunction { arguments: vec![], statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(42)).into(), + FieldElementExpression::Identifier(Identifier::from(main_const_id.clone())) + .into(), ])], signature: DeclarationSignature::new() .inputs(vec![]) @@ -911,7 +600,7 @@ mod tests { .into_iter() .collect(), constants: vec![( - CanonicalConstantIdentifier::new(foo_const_id, "main".into()), + main_const_id, TypedConstantSymbol::Here(TypedConstant::new( TypedExpression::FieldElement(FieldElementExpression::Number( Bn128Field::from(42), diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 81ebca4cb..62fbced18 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -24,9 +24,10 @@ use crate::typed_absy::UBitwidth; use std::collections::HashMap; use crate::typed_absy::{ - ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, DeclarationConstant, - DeclarationSignature, Expr, FieldElementExpression, FunctionCall, FunctionCallExpression, - FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, TypedConstant, + ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression, + CoreIdentifier, DeclarationConstant, DeclarationSignature, Expr, FieldElementExpression, + FunctionCall, FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, + OwnedTypedModuleId, StructExpression, StructExpressionInner, StructType, TypedConstant, TypedConstantSymbol, TypedExpression, TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable, @@ -69,7 +70,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { &mut self, e: FieldElementExpression<'ast, T>, ) -> Result, Self::Error> { - match dbg!(e) { + match e { FieldElementExpression::Identifier(Identifier { id: CoreIdentifier::Constant(c), version, @@ -81,12 +82,28 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { } } + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + BooleanExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + Ok(self.constants.get(&c).cloned().unwrap().try_into().unwrap()) + } + e => fold_boolean_expression(self, e), + } + } + fn fold_uint_expression_inner( &mut self, ty: UBitwidth, e: UExpressionInner<'ast, T>, ) -> Result, Self::Error> { - match dbg!(e) { + match e { UExpressionInner::Identifier(Identifier { id: CoreIdentifier::Constant(c), version, @@ -102,6 +119,48 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { } } + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + ArrayExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + Ok( + ArrayExpression::try_from(self.constants.get(&c).cloned().unwrap()) + .unwrap() + .into_inner(), + ) + } + e => fold_array_expression_inner(self, ty, e), + } + } + + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + StructExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + Ok( + StructExpression::try_from(self.constants.get(&c).cloned().unwrap()) + .unwrap() + .into_inner(), + ) + } + e => fold_struct_expression_inner(self, ty, e), + } + } + fn fold_declaration_constant( &mut self, c: DeclarationConstant<'ast, T>, @@ -132,9 +191,17 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { .into_iter() .map(|(key, tc)| match tc { TypedConstantSymbol::Here(c) => { + let c = self.fold_constant(c)?; + + // replace the existing constants in this expression + let constant_replaced_expression = self.fold_expression(c.expression)?; + + // wrap this expression in a function let wrapper = TypedFunction { arguments: vec![], - statements: vec![TypedStatement::Return(vec![c.expression])], + statements: vec![TypedStatement::Return(vec![ + constant_replaced_expression, + ])], signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), }; @@ -153,6 +220,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { { assert_eq!(expressions.len(), 1); let constant_expression = expressions.pop().unwrap(); + use crate::typed_absy::Constant; if !constant_expression.is_constant() { return Err(Error::ConstantReduction( @@ -162,6 +230,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { }; self.constants .insert(key.clone(), constant_expression.clone()); + Ok(( key, TypedConstantSymbol::Here(TypedConstant { @@ -170,7 +239,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { }), )) } else { - Err(Error::ConstantReduction(key.id.to_string(), key.module)); + Err(Error::ConstantReduction(key.id.to_string(), key.module)) } } _ => unreachable!("all constants should be local"), @@ -179,7 +248,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { functions: m .functions .into_iter() - .map(|(key, fun)| self.fold_function_symbol(fun).map(|f| (key, f))) + .map(|(key, fun)| { + let key = self.fold_declaration_function_key(key)?; + let fun = self.fold_function_symbol(fun)?; + Ok((key, fun)) + }) .collect::>()?, }) } @@ -465,8 +538,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { - println!("STAT {}", s); - let res = match s { TypedStatement::MultipleDefinition( v, @@ -647,6 +718,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { pub fn reduce_program(p: TypedProgram) -> Result, Error> { // inline all constants and replace them in the program + let mut constant_calls_inliner = ConstantCallsInliner::with_program(p.clone()); let p = constant_calls_inliner.fold_program(p)?; @@ -698,7 +770,9 @@ fn reduce_function<'ast, T: Field>( ) -> Result, Error> { let mut versions = Versions::default(); - match ShallowTransformer::transform(f, &generics, &mut versions) { + let mut constants = Constants::default(); + + let f = match ShallowTransformer::transform(f, &generics, &mut versions) { Output::Complete(f) => Ok(f), Output::Incomplete(new_f, new_for_loop_versions) => { let mut for_loop_versions = new_for_loop_versions; @@ -707,8 +781,6 @@ fn reduce_function<'ast, T: Field>( let mut substitutions = Substitutions::default(); - let mut constants = Constants::default(); - let mut hash = None; loop { @@ -765,7 +837,11 @@ fn reduce_function<'ast, T: Field>( } } } - } + }?; + + Propagator::with_constants(&mut constants) + .fold_function(f) + .map_err(|e| Error::Incompatible(format!("{}", e))) } fn compute_hash(f: &TypedFunction) -> u64 { diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 15554a735..6256e10ab 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -37,12 +37,12 @@ mod tests { parameter::DeclarationParameter, variable::DeclarationVariable, ConcreteType, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, }; - use std::collections::HashMap; + use std::collections::BTreeMap; use zokrates_field::Bn128Field; #[test] fn generate_abi_from_typed_ast() { - let mut functions = HashMap::new(); + let mut functions = BTreeMap::new(); functions.insert( ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { @@ -64,7 +64,7 @@ mod tests { }), ); - let mut modules = HashMap::new(); + let mut modules = BTreeMap::new(); modules.insert( "main".into(), TypedModule { diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index d60226713..972a69e66 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -61,6 +61,12 @@ impl<'ast> fmt::Display for Identifier<'ast> { } } +impl<'ast> From> for Identifier<'ast> { + fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> { + Identifier::from(CoreIdentifier::Constant(id)) + } +} + impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { Identifier::from(CoreIdentifier::Source(id)) diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 218f53867..fb38b2f70 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -1007,7 +1007,11 @@ pub fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>( s: DeclarationSignature<'ast, T>, ) -> Result, F::Error> { Ok(DeclarationSignature { - generics: s.generics, + generics: s + .generics + .into_iter() + .map(|g| g.map(|g| f.fold_declaration_constant(g)).transpose()) + .collect::>()?, inputs: s .inputs .into_iter() @@ -1156,7 +1160,11 @@ pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( functions: m .functions .into_iter() - .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) + .map(|(key, fun)| { + let key = f.fold_declaration_function_key(key)?; + let fun = f.fold_function_symbol(fun)?; + Ok((key, fun)) + }) .collect::>()?, }) } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 7520b4b79..0ff64a6d6 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -194,7 +194,8 @@ impl<'ast, T> From> for UExpression<'ast, T> { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { - UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32) + UExpressionInner::Identifier(CoreIdentifier::from(v).into()) + .annotate(UBitwidth::B32) } DeclarationConstant::Expression(e) => e.try_into().unwrap(), } @@ -968,6 +969,11 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq>( (DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1, (DeclarationType::Struct(s0), GType::Struct(s1)) => { s0.canonical_location == s1.canonical_location + && s0 + .members + .iter() + .zip(s1.members.iter()) + .all(|(m0, m1)| check_type(&*m0.ty, &*m1.ty, constants)) } _ => false, } @@ -1347,6 +1353,7 @@ pub mod signature { #[cfg(test)] mod tests { use super::*; + use zokrates_field::Bn128Field; #[test] fn signature() { @@ -1363,7 +1370,7 @@ pub mod signature { //

(field[P]) // (field[Q]) - let generic1 = DeclarationSignature::new() + let generic1 = DeclarationSignature::::new() .generics(vec![Some( GenericIdentifier { name: "P", From 7c248fd77dc1dc23c52ea4d3eef7b9fc4b491ab7 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Sep 2021 12:53:08 +0200 Subject: [PATCH 35/78] pass prog as reference, fix tests --- zokrates_core/src/static_analysis/mod.rs | 3 +- .../src/static_analysis/unconstrained_vars.rs | 29 ++++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index c2c32b536..1a9a3236c 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -150,6 +150,7 @@ impl Analyse for Prog { fn analyse(self) -> Result { log::debug!("Static analyser: Detect unconstrained zir"); - UnconstrainedVariableDetector::detect(self).map_err(Error::from) + UnconstrainedVariableDetector::detect(&self).map_err(Error::from)?; + Ok(self) } } diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index c285a4e47..dcade6835 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -11,7 +11,7 @@ pub struct UnconstrainedVariableDetector { pub(self) variables: HashSet, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct Error(usize); impl fmt::Display for Error { @@ -26,12 +26,12 @@ impl fmt::Display for Error { } impl UnconstrainedVariableDetector { - pub fn detect(p: Prog) -> Result, Error> { + pub fn detect(p: &Prog) -> Result<(), Error> { let mut instance = Self::default(); instance.visit_module(&p); if instance.variables.is_empty() { - Ok(p) + Ok(()) } else { Err(Error(instance.variables.len())) } @@ -61,12 +61,12 @@ mod tests { use zokrates_field::Bn128Field; #[test] - fn should_detect_unconstrained_private_input() { + fn unconstrained_private_input() { // def main(_0) -> (1): // (1 * ~one) * (42 * ~one) == 1 * ~out_0 // return ~out_0 - let _0 = FlatParameter::private(FlatVariable::new(0)); // unused var + let _0 = FlatParameter::private(FlatVariable::new(0)); // unused private parameter let one = FlatVariable::one(); let out_0 = FlatVariable::public(0); @@ -83,12 +83,15 @@ mod tests { returns: vec![out_0], }; - let p = UnconstrainedVariableDetector::detect(p); - assert!(p.is_err()); + let result = UnconstrainedVariableDetector::detect(&p); + assert_eq!( + result.expect_err("expected an error").to_string(), + "Found unconstrained variables during IR analysis (found 1 occurrence)" + ); } #[test] - fn should_pass_with_constrained_private_input() { + fn constrained_private_input() { // def main(_0) -> (1): // (1 * ~one) * (1 * _0) == 1 * ~out_0 // return ~out_0 @@ -102,12 +105,12 @@ mod tests { returns: vec![out_0], }; - let p = UnconstrainedVariableDetector::detect(p); - assert!(p.is_ok()); + let result = UnconstrainedVariableDetector::detect(&p); + assert_eq!(result, Ok(())); } #[test] - fn should_pass_with_directive() { + fn constrained_directive() { // def main(_0) -> (1): // # _1, _2 = ConditionEq((-42) * ~one + 1 * _0) // ((-42) * ~one + 1 * _0) * (1 * _2) == 1 * _1 @@ -155,7 +158,7 @@ mod tests { returns: vec![out_0], }; - let p = UnconstrainedVariableDetector::detect(p); - assert!(p.is_ok()); + let result = UnconstrainedVariableDetector::detect(&p); + assert_eq!(result, Ok(())); } } From 0e57c59542d6fa308d2c81af99b3b666c2b0ae3a Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 6 Sep 2021 16:36:36 +0200 Subject: [PATCH 36/78] remove propagation from constant inliner, simplify --- .../src/static_analysis/constant_inliner.rs | 119 +++++++----------- zokrates_core/src/static_analysis/mod.rs | 10 +- .../src/static_analysis/reducer/mod.rs | 4 - zokrates_core/src/typed_absy/types.rs | 6 + 4 files changed, 49 insertions(+), 90 deletions(-) diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs index bb023ec49..7e38b9f43 100644 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ b/zokrates_core/src/static_analysis/constant_inliner.rs @@ -1,30 +1,16 @@ -use crate::static_analysis::Propagator; -use crate::typed_absy::result_folder::*; +// Static analysis step to replace all imported constants with the imported value +// This does *not* reduce constants to their literal value +// This step cannot fail as the imports were checked during semantics + +use crate::typed_absy::folder::*; use crate::typed_absy::*; use std::collections::HashMap; -use std::fmt; use zokrates_field::Field; -// a map of the constants in this program -// the values are constants whose expression does not include any identifier. It does not have to be a single literal, as -// we keep function calls here to be inlined later +// a map of the canonical constants in this program. with all imported constants reduced to their canonical value type ProgramConstants<'ast, T> = HashMap, TypedConstant<'ast, T>>>; -#[derive(Debug, PartialEq)] -pub enum Error { - Type(String), - Propagation(String), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Type(s) => write!(f, "{}", s), - Error::Propagation(s) => write!(f, "{}", s), - } - } -} pub struct ConstantInliner<'ast, T> { modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, @@ -43,7 +29,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { constants, } } - pub fn inline(p: TypedProgram<'ast, T>) -> Result, Error> { + pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { let constants = ProgramConstants::new(); let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); inliner.fold_program(p) @@ -71,85 +57,64 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } } -impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantInliner<'ast, T> { - type Error = Error; - - fn fold_program( - &mut self, - p: TypedProgram<'ast, T>, - ) -> Result, Self::Error> { - self.fold_module_id(p.main.clone())?; +impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { + fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + self.fold_module_id(p.main.clone()); - Ok(TypedProgram { + TypedProgram { modules: std::mem::take(&mut self.modules), ..p - }) + } } - fn fold_module_id( - &mut self, - id: OwnedTypedModuleId, - ) -> Result { + fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); let m = self.modules.remove(&id).unwrap(); - let m = self.fold_module(m)?; + let m = self.fold_module(m); self.modules.insert(id.clone(), m); self.change_location(current_m_id); } - Ok(id) + id } - fn fold_module( - &mut self, - m: TypedModule<'ast, T>, - ) -> Result, Self::Error> { - Ok(TypedModule { + fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { + TypedModule { constants: m .constants .into_iter() .map(|(id, tc)| { - - let id = self.fold_canonical_constant_identifier(id)?; + let id = self.fold_canonical_constant_identifier(id); let constant = match tc { TypedConstantSymbol::There(imported_id) => { // visit the imported symbol. This triggers visiting the corresponding module if needed - let imported_id = self.fold_canonical_constant_identifier(imported_id)?; - // after that, the constant must have been defined defined in the global map. It is already reduced - // to the maximum, so running propagation isn't required + let imported_id = self.fold_canonical_constant_identifier(imported_id); + // after that, the constant must have been defined defined in the global map self.get_constant(&imported_id).unwrap() } - TypedConstantSymbol::Here(c) => { - let non_propagated_constant = fold_constant(self, c)?; - // folding the constant above only reduces it to an expression containing only literals, not to a single literal. - // propagating with an empty map of constants reduces it to the maximum - Propagator::with_constants(&mut HashMap::default()) - .fold_constant(non_propagated_constant) - .unwrap() - } + TypedConstantSymbol::Here(c) => fold_constant(self, c), }; + self.constants + .get_mut(&self.location) + .unwrap() + .insert(id.id, constant.clone()); - if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(constant.ty.clone()).unwrap() == constant.expression.get_type() { - // add to the constant map - self.constants - .get_mut(&self.location) - .unwrap() - .insert(id.id, constant.clone()); - - Ok(( - id, - TypedConstantSymbol::Here(constant), - )) - } else { - Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant.expression.get_type(), id.id, constant.ty))) - } + (id, TypedConstantSymbol::Here(constant)) }) - .collect::, _>>()?, + .collect(), functions: m .functions - }) + .into_iter() + .map(|(key, fun)| { + ( + self.fold_declaration_function_key(key), + self.fold_function_symbol(fun), + ) + }) + .collect(), + } } } @@ -218,7 +183,7 @@ mod tests { let program = ConstantInliner::inline(program); - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } #[test] @@ -276,7 +241,7 @@ mod tests { let program = ConstantInliner::inline(program); - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } #[test] @@ -337,7 +302,7 @@ mod tests { let program = ConstantInliner::inline(program); - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } #[test] @@ -418,7 +383,7 @@ mod tests { let program = ConstantInliner::inline(program); - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } #[test] @@ -492,7 +457,7 @@ mod tests { let program = ConstantInliner::inline(program); - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } #[test] @@ -622,6 +587,6 @@ mod tests { .collect(), }; - assert_eq!(program, Ok(expected_program)) + assert_eq!(program, expected_program) } } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 2285115b4..3b8751f1b 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -40,13 +40,6 @@ pub enum Error { Reducer(self::reducer::Error), Propagation(self::propagation::Error), NonConstantArgument(self::constant_argument_checker::Error), - ConstantInliner(self::constant_inliner::Error), -} - -impl From for Error { - fn from(e: self::constant_inliner::Error) -> Self { - Error::ConstantInliner(e) - } } impl From for Error { @@ -73,7 +66,6 @@ impl fmt::Display for Error { Error::Reducer(e) => write!(f, "{}", e), Error::Propagation(e) => write!(f, "{}", e), Error::NonConstantArgument(e) => write!(f, "{}", e), - Error::ConstantInliner(e) => write!(f, "{}", e), } } } @@ -82,7 +74,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> { // inline user-defined constants log::debug!("Static analyser: Inline constants"); - let r = ConstantInliner::inline(self).map_err(Error::from)?; + let r = ConstantInliner::inline(self); log::trace!("\n{}", r); // isolate branches diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 62fbced18..4d4b45073 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -211,10 +211,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { &self.program, )?; - if inlined_wrapper.statements.len() > 1 { - return Err(Error::ConstantReduction(key.id.to_string(), key.module)); - }; - if let TypedStatement::Return(mut expressions) = inlined_wrapper.statements.pop().unwrap() { diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 0ff64a6d6..6efbaf4c6 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -111,6 +111,12 @@ pub struct CanonicalConstantIdentifier<'ast> { pub id: ConstantIdentifier<'ast>, } +impl<'ast> fmt::Display for CanonicalConstantIdentifier<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}/{}", self.module.display(), self.id) + } +} + impl<'ast> CanonicalConstantIdentifier<'ast> { pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self { CanonicalConstantIdentifier { module, id } From 5f7c03dde2f17aeab39333fddb2c049b5b81e614 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 6 Sep 2021 17:31:00 +0200 Subject: [PATCH 37/78] detect wrong type in reducer (moved over from const inliner) --- .../src/static_analysis/reducer/mod.rs | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 4d4b45073..9c5708104 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -189,10 +189,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { constants: m .constants .into_iter() - .map(|(key, tc)| match tc { + .map(|(id, tc)| match tc { TypedConstantSymbol::Here(c) => { let c = self.fold_constant(c)?; + let ty = c.ty; + // replace the existing constants in this expression let constant_replaced_expression = self.fold_expression(c.expression)?; @@ -202,7 +204,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { statements: vec![TypedStatement::Return(vec![ constant_replaced_expression, ])], - signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), + signature: DeclarationSignature::new().outputs(vec![ty.clone()]), }; let mut inlined_wrapper = reduce_function( @@ -220,22 +222,27 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { use crate::typed_absy::Constant; if !constant_expression.is_constant() { return Err(Error::ConstantReduction( - key.id.to_string(), - key.module, + id.id.to_string(), + id.module, )); }; - self.constants - .insert(key.clone(), constant_expression.clone()); - - Ok(( - key, - TypedConstantSymbol::Here(TypedConstant { - expression: constant_expression, - ty: c.ty, - }), - )) + + use crate::typed_absy::Typed; + if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(ty.clone()).unwrap() == constant_expression.get_type() { + // add to the constant map + self.constants.insert(id.clone(), constant_expression.clone()); + Ok(( + id, + TypedConstantSymbol::Here(TypedConstant { + expression: constant_expression, + ty, + }), + )) + } else { + Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, ty))) + } } else { - Err(Error::ConstantReduction(key.id.to_string(), key.module)) + Err(Error::ConstantReduction(id.id.to_string(), id.module)) } } _ => unreachable!("all constants should be local"), @@ -272,6 +279,7 @@ pub enum Error { NoProgress, LoopTooLarge(u128), ConstantReduction(String, OwnedTypedModuleId), + Type(String), } impl fmt::Display for Error { @@ -286,6 +294,7 @@ impl fmt::Display for Error { Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"), Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE), Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()), + Error::Type(message) => write!(f, "{}", message), } } } From ece9694d6bfcdd69d9afd9996bba5ffd909cc332 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 6 Sep 2021 19:46:41 +0200 Subject: [PATCH 38/78] clippy --- zokrates_core/src/static_analysis/reducer/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 10719b6aa..d2a6aeb1b 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -29,8 +29,8 @@ use crate::typed_absy::{ FunctionCall, FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, StructExpression, StructExpressionInner, StructType, TypedConstant, TypedConstantSymbol, TypedExpression, TypedExpressionList, TypedExpressionListInner, - TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, - UExpressionInner, Variable, + TypedFunction, TypedFunctionSymbol, TypedModule, TypedModuleId, TypedProgram, TypedStatement, + UExpression, UExpressionInner, Variable, }; use std::convert::{TryFrom, TryInto}; @@ -73,7 +73,7 @@ impl<'ast, T> ConstantCallsInliner<'ast, T> { prev } - fn treated(&self, id: &OwnedTypedModuleId) -> bool { + fn treated(&self, id: &TypedModuleId) -> bool { self.treated.contains(id) } } From 969bd1c015276f1ca4093734bd1342e9d02a9f24 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 7 Sep 2021 13:50:12 +0200 Subject: [PATCH 39/78] keep module in program during visit --- zokrates_core/src/static_analysis/reducer/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index d2a6aeb1b..1e4a5d4df 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -88,7 +88,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); - let m = self.program.modules.remove(&id).unwrap(); + let m = self.program.modules.get(&id).cloned().unwrap(); let m = self.fold_module(m)?; self.program.modules.insert(id.clone(), m); self.change_location(current_m_id); From 60f918945871bba1425875582b51ea308f404835 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 8 Sep 2021 15:51:28 +0200 Subject: [PATCH 40/78] add range semantics to docs --- changelogs/unreleased/992-dark64 | 1 + zokrates_book/src/language/control_flow.md | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelogs/unreleased/992-dark64 diff --git a/changelogs/unreleased/992-dark64 b/changelogs/unreleased/992-dark64 new file mode 100644 index 000000000..bdd50880b --- /dev/null +++ b/changelogs/unreleased/992-dark64 @@ -0,0 +1 @@ +Add range semantics to docs \ No newline at end of file diff --git a/zokrates_book/src/language/control_flow.md b/zokrates_book/src/language/control_flow.md index 1cdd412d8..9e47d289f 100644 --- a/zokrates_book/src/language/control_flow.md +++ b/zokrates_book/src/language/control_flow.md @@ -52,6 +52,7 @@ For loops are available with the following syntax: ``` The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. They can depend on generic parameters. +The range is half-open, meaning it is bounded inclusively below and exclusively above. The range `start..end` contains all values within `start <= x < end`. The range is empty if `start >= end`. > For loops are only syntactic sugar for repeating a block of statements many times. No condition of the type `index < max` is being checked at run-time after each iteration. Instead, at compile-time, the index is incremented and the block is executed again. Therefore, assigning to the loop index does not have any influence on the number of iterations performed and is considered bad practice. From 187a1e834b09f2d5dfd5e1c61512c8f6793a6adf Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 10 Sep 2021 23:27:19 +0200 Subject: [PATCH 41/78] refactor for ordered symbols, clean, propagate constants globally to cover constants in function key --- zokrates_cli/examples/call_in_const.zok | 6 +- .../compile_errors/ambiguous_generic_call.zok | 14 + .../ambiguous_generic_call_too_strict.zok | 14 + zokrates_core/src/semantics.rs | 115 +-- .../src/static_analysis/constant_inliner.rs | 592 ------------ .../src/static_analysis/constant_resolver.rs | 844 ++++++++++++++++++ .../static_analysis/flatten_complex_types.rs | 10 +- zokrates_core/src/static_analysis/mod.rs | 4 +- .../src/static_analysis/propagation.rs | 26 +- .../reducer/constants_reader.rs | 155 ++++ .../src/static_analysis/reducer/inline.rs | 35 +- .../src/static_analysis/reducer/mod.rs | 439 ++++----- zokrates_core/src/typed_absy/abi.rs | 17 +- zokrates_core/src/typed_absy/folder.rs | 78 +- zokrates_core/src/typed_absy/mod.rs | 169 +++- zokrates_core/src/typed_absy/result_folder.rs | 70 +- 16 files changed, 1542 insertions(+), 1046 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok create mode 100644 zokrates_cli/examples/compile_errors/ambiguous_generic_call_too_strict.zok delete mode 100644 zokrates_core/src/static_analysis/constant_inliner.rs create mode 100644 zokrates_core/src/static_analysis/constant_resolver.rs create mode 100644 zokrates_core/src/static_analysis/reducer/constants_reader.rs diff --git a/zokrates_cli/examples/call_in_const.zok b/zokrates_cli/examples/call_in_const.zok index bf3cac375..d6e73dc3e 100644 --- a/zokrates_cli/examples/call_in_const.zok +++ b/zokrates_cli/examples/call_in_const.zok @@ -1,5 +1,9 @@ from "./call_in_const_aux.zok" import A, foo, F -const field[A] Y = [...foo::(F)[..A - 1], 1] + +def bar(field[A] x) -> field[A]: + return x + +const field[A] Y = [...bar(foo::(F))[..A - 1], 1] def main(field[A] X): assert(X == Y) diff --git a/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok b/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok new file mode 100644 index 000000000..5224a4be0 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok @@ -0,0 +1,14 @@ +// this should not compile, as A == B + +const u32 A = 2 +const u32 B = 1 + +def foo(field[A] a) -> bool: + return true + +def foo(field[B] a) -> bool: + return true + +def main(): + assert(foo([1])) + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/ambiguous_generic_call_too_strict.zok b/zokrates_cli/examples/compile_errors/ambiguous_generic_call_too_strict.zok new file mode 100644 index 000000000..5abd8acdb --- /dev/null +++ b/zokrates_cli/examples/compile_errors/ambiguous_generic_call_too_strict.zok @@ -0,0 +1,14 @@ +// this should actually compile, as A != B + +const u32 A = 2 +const u32 B = 1 + +def foo(field[A] a) -> bool: + return true + +def foo(field[B] a) -> bool: + return true + +def main(): + assert(foo([1])) + return \ No newline at end of file diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index ef2237082..44eeb7100 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -517,8 +517,7 @@ impl<'ast, T: Field> Checker<'ast, T> { declaration: SymbolDeclarationNode<'ast>, module_id: &ModuleId, state: &mut State<'ast, T>, - functions: &mut TypedFunctionSymbols<'ast, T>, - constants: &mut TypedConstantSymbols<'ast, T>, + symbols: &mut TypedSymbolDeclarations<'ast, T>, symbol_unifier: &mut SymbolUnifier<'ast, T>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -578,12 +577,14 @@ impl<'ast, T: Field> Checker<'ast, T> { .in_file(module_id), ), true => { - constants.push(( - CanonicalConstantIdentifier::new( - declaration.id, - module_id.into(), - ), - TypedConstantSymbol::Here(c.clone()), + symbols.push(TypedSymbolDeclaration::Constant( + TypedConstantSymbolDeclaration { + id: CanonicalConstantIdentifier::new( + declaration.id, + module_id.into(), + ), + symbol: TypedConstantSymbol::Here(c.clone()), + }, )); self.insert_into_scope(Variable::with_id_and_type( CoreIdentifier::Constant(CanonicalConstantIdentifier::new( @@ -632,14 +633,16 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .signature(funct.signature.clone()), ); - functions.insert( - DeclarationFunctionKey::with_location( - module_id.to_path_buf(), - declaration.id, - ) - .signature(funct.signature.clone()), - TypedFunctionSymbol::Here(funct), - ); + symbols.push(TypedSymbolDeclaration::Function( + TypedFunctionSymbolDeclaration { + key: DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration.id, + ) + .signature(funct.signature.clone()), + symbol: TypedFunctionSymbol::Here(funct), + }, + )); } Err(e) => { errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); @@ -657,13 +660,12 @@ impl<'ast, T: Field> Checker<'ast, T> { .typed_modules .get(&import.module_id) .unwrap() - .functions - .iter() - .filter(|(k, _)| k.id == import.symbol_id) - .map(|(_, v)| DeclarationFunctionKey { + .functions_iter() + .filter(|d| d.key.id == import.symbol_id) + .map(|d| DeclarationFunctionKey { module: import.module_id.to_path_buf(), id: import.symbol_id, - signature: v.signature(&state.typed_modules).clone(), + signature: d.symbol.signature(&state.typed_modules).clone(), }) .collect(); @@ -738,7 +740,10 @@ impl<'ast, T: Field> Checker<'ast, T> { let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id); let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into()); - constants.push((id.clone(), TypedConstantSymbol::There(imported_id))); + symbols.push(TypedSymbolDeclaration::Constant(TypedConstantSymbolDeclaration { + id: id.clone(), + symbol: TypedConstantSymbol::There(imported_id) + })); self.insert_into_scope(Variable::with_id_and_type(CoreIdentifier::Constant(CanonicalConstantIdentifier::new( declaration.id, module_id.into(), @@ -781,10 +786,12 @@ impl<'ast, T: Field> Checker<'ast, T> { let local_key = candidate.clone().id(declaration.id).module(module_id.to_path_buf()); self.functions.insert(local_key.clone()); - functions.insert( - local_key, - TypedFunctionSymbol::There(candidate, - ), + symbols.push( + TypedSymbolDeclaration::Function(TypedFunctionSymbolDeclaration { + key: local_key, + symbol: TypedFunctionSymbol::There(candidate, + ), + }) ); } } @@ -816,11 +823,16 @@ impl<'ast, T: Field> Checker<'ast, T> { DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) .signature(funct.typed_signature()), ); - functions.insert( - DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) + symbols.push(TypedSymbolDeclaration::Function( + TypedFunctionSymbolDeclaration { + key: DeclarationFunctionKey::with_location( + module_id.to_path_buf(), + declaration.id, + ) .signature(funct.typed_signature()), - TypedFunctionSymbol::Flat(funct), - ); + symbol: TypedFunctionSymbol::Flat(funct), + }, + )); } _ => unreachable!(), }; @@ -838,8 +850,7 @@ impl<'ast, T: Field> Checker<'ast, T> { module_id: &ModuleId, state: &mut State<'ast, T>, ) -> Result<(), Vec> { - let mut checked_functions = TypedFunctionSymbols::new(); - let mut checked_constants = TypedConstantSymbols::new(); + let mut checked_symbols = TypedSymbolDeclarations::new(); // check if the module was already removed from the untyped ones let to_insert = match state.modules.remove(module_id) { @@ -860,15 +871,13 @@ impl<'ast, T: Field> Checker<'ast, T> { declaration, module_id, state, - &mut checked_functions, - &mut checked_constants, + &mut checked_symbols, &mut symbol_unifier, )? } Some(TypedModule { - functions: checked_functions, - constants: checked_constants, + symbols: checked_symbols, }) } }; @@ -887,9 +896,9 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module - .functions + .symbols .iter() - .filter(|(key, _)| key.id == "main") + .filter(|s| matches!(s, TypedSymbolDeclaration::Function(d) if d.key.id == "main")) .count() { 1 => Ok(()), @@ -3579,17 +3588,17 @@ mod tests { assert_eq!( state.typed_modules.get(&PathBuf::from("bar")), Some(&TypedModule { - functions: vec![( + symbols: vec![TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("bar", "main") .signature(DeclarationSignature::new()), TypedFunctionSymbol::There( DeclarationFunctionKey::with_location("foo", "main") .signature(DeclarationSignature::new()), ) - )] + ) + .into()] .into_iter() .collect(), - constants: TypedConstantSymbols::default() }) ); } @@ -3832,21 +3841,23 @@ mod tests { .typed_modules .get(&*MODULE_ID) .unwrap() - .functions - .contains_key( - &DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo") - .signature(DeclarationSignature::new()) - )); + .functions_iter() + .find(|d| d.key + == DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo") + .signature(DeclarationSignature::new())) + .is_some()); + assert!(state .typed_modules .get(&*MODULE_ID) .unwrap() - .functions - .contains_key( - &DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo").signature( - DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) - ) - )) + .functions_iter() + .find(|d| d.key + == DeclarationFunctionKey::with_location((*MODULE_ID).clone(), "foo") + .signature( + DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) + )) + .is_some()); } #[test] diff --git a/zokrates_core/src/static_analysis/constant_inliner.rs b/zokrates_core/src/static_analysis/constant_inliner.rs deleted file mode 100644 index 7e38b9f43..000000000 --- a/zokrates_core/src/static_analysis/constant_inliner.rs +++ /dev/null @@ -1,592 +0,0 @@ -// Static analysis step to replace all imported constants with the imported value -// This does *not* reduce constants to their literal value -// This step cannot fail as the imports were checked during semantics - -use crate::typed_absy::folder::*; -use crate::typed_absy::*; -use std::collections::HashMap; -use zokrates_field::Field; - -// a map of the canonical constants in this program. with all imported constants reduced to their canonical value -type ProgramConstants<'ast, T> = - HashMap, TypedConstant<'ast, T>>>; - -pub struct ConstantInliner<'ast, T> { - modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, - constants: ProgramConstants<'ast, T>, -} - -impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { - pub fn new( - modules: TypedModules<'ast, T>, - location: OwnedTypedModuleId, - constants: ProgramConstants<'ast, T>, - ) -> Self { - ConstantInliner { - modules, - location, - constants, - } - } - pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - let constants = ProgramConstants::new(); - let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); - inliner.fold_program(p) - } - - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { - let prev = self.location.clone(); - self.location = location; - self.constants.entry(self.location.clone()).or_default(); - prev - } - - fn treated(&self, id: &TypedModuleId) -> bool { - self.constants.contains_key(id) - } - - fn get_constant( - &self, - id: &CanonicalConstantIdentifier<'ast>, - ) -> Option> { - self.constants - .get(&id.module) - .and_then(|constants| constants.get(&id.id)) - .cloned() - } -} - -impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { - fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - self.fold_module_id(p.main.clone()); - - TypedProgram { - modules: std::mem::take(&mut self.modules), - ..p - } - } - - fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { - // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet - if !self.treated(&id) { - let current_m_id = self.change_location(id.clone()); - let m = self.modules.remove(&id).unwrap(); - let m = self.fold_module(m); - self.modules.insert(id.clone(), m); - self.change_location(current_m_id); - } - id - } - - fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> { - TypedModule { - constants: m - .constants - .into_iter() - .map(|(id, tc)| { - let id = self.fold_canonical_constant_identifier(id); - - let constant = match tc { - TypedConstantSymbol::There(imported_id) => { - // visit the imported symbol. This triggers visiting the corresponding module if needed - let imported_id = self.fold_canonical_constant_identifier(imported_id); - // after that, the constant must have been defined defined in the global map - self.get_constant(&imported_id).unwrap() - } - TypedConstantSymbol::Here(c) => fold_constant(self, c), - }; - self.constants - .get_mut(&self.location) - .unwrap() - .insert(id.id, constant.clone()); - - (id, TypedConstantSymbol::Here(constant)) - }) - .collect(), - functions: m - .functions - .into_iter() - .map(|(key, fun)| { - ( - self.fold_declaration_function_key(key), - self.fold_function_symbol(fun), - ) - }) - .collect(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::typed_absy::types::DeclarationSignature; - use crate::typed_absy::{ - DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression, - GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, - TypedStatement, - }; - use zokrates_field::Bn128Field; - - #[test] - fn inline_const_field() { - // const field a = 1 - // - // def main() -> field: - // return a - - let const_id = "a"; - let main: TypedFunction = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(const_id)).into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; - - let constants: TypedConstantSymbols<_> = vec![( - CanonicalConstantIdentifier::new(const_id, "main".into()), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))), - DeclarationType::FieldElement, - )), - )] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(main), - )] - .into_iter() - .collect(), - constants: constants.clone(), - }, - )] - .into_iter() - .collect(), - }; - - let expected_program = program.clone(); - - let program = ConstantInliner::inline(program); - - assert_eq!(program, expected_program) - } - - #[test] - fn inline_const_boolean() { - // const bool a = true - // - // def main() -> bool: - // return main.zok/a - - let const_id = CanonicalConstantIdentifier::new("a", "main".into()); - let main: TypedFunction = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( - Identifier::from(const_id.clone()), - ) - .into()])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Boolean]), - }; - - let constants: TypedConstantSymbols<_> = vec![( - const_id, - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Boolean(BooleanExpression::Value(true)), - DeclarationType::Boolean, - )), - )] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Boolean]), - ), - TypedFunctionSymbol::Here(main), - )] - .into_iter() - .collect(), - constants: constants.clone(), - }, - )] - .into_iter() - .collect(), - }; - - let expected_program = program.clone(); - - let program = ConstantInliner::inline(program); - - assert_eq!(program, expected_program) - } - - #[test] - fn inline_const_uint() { - // const u32 a = 0x00000001 - // - // def main() -> u32: - // return a - - let const_id = CanonicalConstantIdentifier::new("a", "main".into()); - let main: TypedFunction = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( - Identifier::from(const_id.clone()), - ) - .annotate(UBitwidth::B32) - .into()])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), - }; - - let constants: TypedConstantSymbols<_> = vec![( - const_id, - TypedConstantSymbol::Here(TypedConstant::new( - UExpressionInner::Value(1u128) - .annotate(UBitwidth::B32) - .into(), - DeclarationType::Uint(UBitwidth::B32), - )), - )] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), - ), - TypedFunctionSymbol::Here(main), - )] - .into_iter() - .collect(), - constants: constants.clone(), - }, - )] - .into_iter() - .collect(), - }; - - let expected_program = program.clone(); - - let program = ConstantInliner::inline(program); - - assert_eq!(program, expected_program) - } - - #[test] - fn inline_const_field_array() { - // const field[2] a = [2, 2] - // - // def main() -> field: - // return a[0] + a[1] - - let const_id = CanonicalConstantIdentifier::new("a", "main".into()); - let main: TypedFunction = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( - FieldElementExpression::select( - ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2usize), - UExpressionInner::Value(0u128).annotate(UBitwidth::B32), - ) - .into(), - FieldElementExpression::select( - ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) - .annotate(GType::FieldElement, 2usize), - UExpressionInner::Value(1u128).annotate(UBitwidth::B32), - ) - .into(), - ) - .into()])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; - - let constants: TypedConstantSymbols<_> = vec![( - const_id.clone(), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::Array( - ArrayExpressionInner::Value( - vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ] - .into(), - ) - .annotate(GType::FieldElement, 2usize), - ), - DeclarationType::Array(DeclarationArrayType::new( - DeclarationType::FieldElement, - 2u32, - )), - )), - )] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(main), - )] - .into_iter() - .collect(), - constants: constants.clone(), - }, - )] - .into_iter() - .collect(), - }; - - let expected_program = program.clone(); - - let program = ConstantInliner::inline(program); - - assert_eq!(program, expected_program) - } - - #[test] - fn inline_nested_const_field() { - // const field a = 1 - // const field b = a + 1 - // - // def main() -> field: - // return b - - let const_a_id = CanonicalConstantIdentifier::new("a", "main".into()); - let const_b_id = CanonicalConstantIdentifier::new("a", "main".into()); - - let main: TypedFunction = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }; - - let program = TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(main), - )] - .into_iter() - .collect(), - constants: vec![ - ( - const_a_id.clone(), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(1), - )), - DeclarationType::FieldElement, - )), - ), - ( - const_b_id.clone(), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from( - const_a_id.clone(), - )), - box FieldElementExpression::Number(Bn128Field::from(1)), - )), - DeclarationType::FieldElement, - )), - ), - ] - .into_iter() - .collect(), - }, - )] - .into_iter() - .collect(), - }; - - let expected_program = program.clone(); - - let program = ConstantInliner::inline(program); - - assert_eq!(program, expected_program) - } - - #[test] - fn inline_imported_constant() { - // --------------------- - // module `foo` - // -------------------- - // const field FOO = 42 - // - // def main(): - // return - // - // --------------------- - // module `main` - // --------------------- - // from "foo" import FOO - // - // def main() -> field: - // return FOO - - let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into()); - let foo_module = TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("foo", "main") - .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![], - signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), - }), - )] - .into_iter() - .collect(), - constants: vec![( - foo_const_id.clone(), - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(42), - )), - DeclarationType::FieldElement, - )), - )] - .into_iter() - .collect(), - }; - - let main_const_id = CanonicalConstantIdentifier::new("FOO", "main".into()); - let main_module = TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(main_const_id.clone())) - .into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }), - )] - .into_iter() - .collect(), - constants: vec![( - main_const_id.clone(), - TypedConstantSymbol::There(foo_const_id), - )] - .into_iter() - .collect(), - }; - - let program = TypedProgram { - main: "main".into(), - modules: vec![ - ("main".into(), main_module), - ("foo".into(), foo_module.clone()), - ] - .into_iter() - .collect(), - }; - - let program = ConstantInliner::inline(program); - let expected_main_module = TypedModule { - functions: vec![( - DeclarationFunctionKey::with_location("main", "main").signature( - DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier(Identifier::from(main_const_id.clone())) - .into(), - ])], - signature: DeclarationSignature::new() - .inputs(vec![]) - .outputs(vec![DeclarationType::FieldElement]), - }), - )] - .into_iter() - .collect(), - constants: vec![( - main_const_id, - TypedConstantSymbol::Here(TypedConstant::new( - TypedExpression::FieldElement(FieldElementExpression::Number( - Bn128Field::from(42), - )), - DeclarationType::FieldElement, - )), - )] - .into_iter() - .collect(), - }; - - let expected_program: TypedProgram = TypedProgram { - main: "main".into(), - modules: vec![ - ("main".into(), expected_main_module), - ("foo".into(), foo_module), - ] - .into_iter() - .collect(), - }; - - assert_eq!(program, expected_program) - } -} diff --git a/zokrates_core/src/static_analysis/constant_resolver.rs b/zokrates_core/src/static_analysis/constant_resolver.rs new file mode 100644 index 000000000..b6c4f1368 --- /dev/null +++ b/zokrates_core/src/static_analysis/constant_resolver.rs @@ -0,0 +1,844 @@ +// Static analysis step to replace all imported constants with the imported value +// This does *not* reduce constants to their literal value +// This step cannot fail as the imports were checked during semantics + +use crate::typed_absy::folder::*; +use crate::typed_absy::*; +use std::collections::HashMap; +use zokrates_field::Field; + +// a map of the canonical constants in this program. with all imported constants reduced to their canonical value +type ProgramConstants<'ast, T> = + HashMap, TypedConstant<'ast, T>>>; + +pub struct ConstantInliner<'ast, T> { + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, + constants: ProgramConstants<'ast, T>, +} + +impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { + pub fn new( + modules: TypedModules<'ast, T>, + location: OwnedTypedModuleId, + constants: ProgramConstants<'ast, T>, + ) -> Self { + ConstantInliner { + modules, + location, + constants, + } + } + pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + let constants = ProgramConstants::new(); + let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); + inliner.fold_program(p) + } + + fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + let prev = self.location.clone(); + self.location = location; + self.constants.entry(self.location.clone()).or_default(); + prev + } + + fn treated(&self, id: &TypedModuleId) -> bool { + self.constants.contains_key(id) + } + + fn get_constant( + &self, + id: &CanonicalConstantIdentifier<'ast>, + ) -> Option> { + self.constants + .get(&id.module) + .and_then(|constants| constants.get(&id.id)) + .cloned() + } +} + +impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { + fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + self.fold_module_id(p.main.clone()); + + TypedProgram { + modules: std::mem::take(&mut self.modules), + ..p + } + } + + fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId { + // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet + if !self.treated(&id) { + let current_m_id = self.change_location(id.clone()); + let m = self.modules.remove(&id).unwrap(); + let m = self.fold_module(m); + self.modules.insert(id.clone(), m); + self.change_location(current_m_id); + } + id + } + + fn fold_constant_symbol_declaration( + &mut self, + c: TypedConstantSymbolDeclaration<'ast, T>, + ) -> TypedConstantSymbolDeclaration<'ast, T> { + let id = self.fold_canonical_constant_identifier(c.id); + + let constant = match c.symbol { + TypedConstantSymbol::There(imported_id) => { + // visit the imported symbol. This triggers visiting the corresponding module if needed + let imported_id = self.fold_canonical_constant_identifier(imported_id); + // after that, the constant must have been defined defined in the global map + self.get_constant(&imported_id).unwrap() + } + TypedConstantSymbol::Here(c) => fold_constant(self, c), + }; + self.constants + .get_mut(&self.location) + .unwrap() + .insert(id.id, constant.clone()); + + TypedConstantSymbolDeclaration { + id, + symbol: TypedConstantSymbol::Here(constant), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed_absy::types::DeclarationSignature; + use crate::typed_absy::{ + DeclarationArrayType, DeclarationFunctionKey, DeclarationType, FieldElementExpression, + GType, Identifier, TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, + TypedStatement, + }; + use zokrates_field::Bn128Field; + + #[test] + fn inline_const_field() { + // in the absence of imports, a module is left unchanged + + // const field a = 1 + // + // def main() -> field: + // return a + + let const_id = "a"; + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(const_id)).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + CanonicalConstantIdentifier::new(const_id, "main".into()), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ) + .into(), + ], + }, + )] + .into_iter() + .collect(), + }; + + let expected_program = program.clone(); + + let program = ConstantInliner::inline(program); + + assert_eq!(program, expected_program) + } + + #[test] + fn no_op_const_boolean() { + // in the absence of imports, a module is left unchanged + + // const bool a = true + // + // def main() -> bool: + // return main.zok/a + + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier( + Identifier::from(const_id.clone()), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Boolean(BooleanExpression::Value(true)), + DeclarationType::Boolean, + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Boolean]), + ), + TypedFunctionSymbol::Here(main), + ) + .into(), + ], + }, + )] + .into_iter() + .collect(), + }; + + let expected_program = program.clone(); + + let program = ConstantInliner::inline(program); + + assert_eq!(program, expected_program) + } + + #[test] + fn no_op_const_uint() { + // in the absence of imports, a module is left unchanged + + // const u32 a = 0x00000001 + // + // def main() -> u32: + // return a + + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier( + Identifier::from(const_id.clone()), + ) + .annotate(UBitwidth::B32) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + const_id, + TypedConstantSymbol::Here(TypedConstant::new( + UExpressionInner::Value(1u128) + .annotate(UBitwidth::B32) + .into(), + DeclarationType::Uint(UBitwidth::B32), + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::Uint(UBitwidth::B32)]), + ), + TypedFunctionSymbol::Here(main), + ) + .into(), + ], + }, + )] + .into_iter() + .collect(), + }; + + let expected_program = program.clone(); + + let program = ConstantInliner::inline(program); + + assert_eq!(program, expected_program) + } + + #[test] + fn no_op_const_field_array() { + // in the absence of imports, a module is left unchanged + + // const field[2] a = [2, 2] + // + // def main() -> field: + // return a[0] + a[1] + + let const_id = CanonicalConstantIdentifier::new("a", "main".into()); + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add( + FieldElementExpression::select( + ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) + .annotate(GType::FieldElement, 2usize), + UExpressionInner::Value(0u128).annotate(UBitwidth::B32), + ) + .into(), + FieldElementExpression::select( + ArrayExpressionInner::Identifier(Identifier::from(const_id.clone())) + .annotate(GType::FieldElement, 2usize), + UExpressionInner::Value(1u128).annotate(UBitwidth::B32), + ) + .into(), + ) + .into()])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)) + .into(), + FieldElementExpression::Number(Bn128Field::from(2)) + .into(), + ] + .into(), + ) + .annotate(GType::FieldElement, 2usize), + ), + DeclarationType::Array(DeclarationArrayType::new( + DeclarationType::FieldElement, + 2u32, + )), + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ) + .into(), + ], + }, + )] + .into_iter() + .collect(), + }; + + let expected_program = program.clone(); + + let program = ConstantInliner::inline(program); + + assert_eq!(program, expected_program) + } + + #[test] + fn no_op_nested_const_field() { + // const field a = 1 + // const field b = a + 1 + // + // def main() -> field: + // return b + + let const_a_id = CanonicalConstantIdentifier::new("a", "main".into()); + let const_b_id = CanonicalConstantIdentifier::new("a", "main".into()); + + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from(const_b_id.clone())).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + const_a_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(1), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedConstantSymbolDeclaration::new( + const_b_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Add( + box FieldElementExpression::Identifier(Identifier::from( + const_a_id.clone(), + )), + box FieldElementExpression::Number(Bn128Field::from(1)), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ) + .into(), + ], + }, + )] + .into_iter() + .collect(), + }; + + let expected_program = program.clone(); + + let program = ConstantInliner::inline(program); + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_imported_constant() { + // --------------------- + // module `foo` + // -------------------- + // const field FOO = 42 + // const field BAR = FOO + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // from "foo" import BAR + // + // def main() -> field: + // return FOO + + // Should be resolved to + + // --------------------- + // module `foo` + // -------------------- + // const field BAR = ./foo.zok/FOO + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // const field FOO = 42 + // + // def main() -> field: + // return FOO + + let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into()); + let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into()); + let foo_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + foo_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(42), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedConstantSymbolDeclaration::new( + bar_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Identifier( + foo_const_id.clone().into(), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("foo", "main") + .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![], + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + }), + ) + .into(), + ], + }; + + let main_const_id = CanonicalConstantIdentifier::new("FOO", "main".into()); + let main_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + main_const_id.clone(), + TypedConstantSymbol::There(bar_const_id), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from( + main_const_id.clone(), + )) + .into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + ) + .into(), + ], + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), main_module), + ("foo".into(), foo_module.clone()), + ] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + let expected_main_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + main_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Identifier( + foo_const_id.into(), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from( + main_const_id.clone(), + )) + .into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + ) + .into(), + ], + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), expected_main_module), + ("foo".into(), foo_module), + ] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } + + #[test] + fn inline_imported_constant_with_generics() { + // --------------------- + // module `foo` + // -------------------- + // const field FOO = 2 + // const field[FOO] BAR = [1; FOO] + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // from "foo" import FOO + // from "foo" import BAR + // const field[FOO] BAZ = BAR + // + // def main() -> field: + // return FOO + + // Should be resolved to + + // --------------------- + // module `foo` + // -------------------- + // const field FOO = 2 + // const field[FOO] BAR = [1; FOO] + // + // def main(): + // return + // + // --------------------- + // module `main` + // --------------------- + // const FOO = 2 + // const BAR = [1; ./foo.zok/FOO] + // const field[FOO] BAZ = BAR + // + // def main() -> field: + // return FOO + + let foo_const_id = CanonicalConstantIdentifier::new("FOO", "foo".into()); + let bar_const_id = CanonicalConstantIdentifier::new("BAR", "foo".into()); + let foo_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + foo_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::FieldElement(FieldElementExpression::Number( + Bn128Field::from(2), + )), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedConstantSymbolDeclaration::new( + bar_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Repeat( + box FieldElementExpression::Number(Bn128Field::from(1)).into(), + box UExpression::from(foo_const_id.clone()), + ) + .annotate(Type::FieldElement, foo_const_id.clone()), + ), + DeclarationType::Array(DeclarationArrayType::new( + DeclarationType::FieldElement, + DeclarationConstant::Constant(foo_const_id.clone()), + )), + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("foo", "main") + .signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![], + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + }), + ) + .into(), + ], + }; + + let main_foo_const_id = CanonicalConstantIdentifier::new("FOO", "main".into()); + let main_bar_const_id = CanonicalConstantIdentifier::new("BAR", "main".into()); + let main_baz_const_id = CanonicalConstantIdentifier::new("BAZ", "main".into()); + + let main_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + main_foo_const_id.clone(), + TypedConstantSymbol::There(foo_const_id.clone()), + ) + .into(), + TypedConstantSymbolDeclaration::new( + main_bar_const_id.clone(), + TypedConstantSymbol::There(bar_const_id), + ) + .into(), + TypedConstantSymbolDeclaration::new( + main_baz_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Identifier(main_bar_const_id.clone().into()) + .annotate(Type::FieldElement, main_foo_const_id.clone()), + ), + DeclarationType::Array(DeclarationArrayType::new( + DeclarationType::FieldElement, + DeclarationConstant::Constant(foo_const_id.clone()), + )), + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from( + main_foo_const_id.clone(), + )) + .into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + ) + .into(), + ], + }; + + let program = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), main_module), + ("foo".into(), foo_module.clone()), + ] + .into_iter() + .collect(), + }; + + let program = ConstantInliner::inline(program); + let expected_main_module = TypedModule { + symbols: vec![ + TypedConstantSymbolDeclaration::new( + main_foo_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + FieldElementExpression::Number(Bn128Field::from(2)).into(), + DeclarationType::FieldElement, + )), + ) + .into(), + TypedConstantSymbolDeclaration::new( + main_bar_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Repeat( + box FieldElementExpression::Number(Bn128Field::from(1)).into(), + box UExpression::from(foo_const_id.clone()), + ) + .annotate(Type::FieldElement, foo_const_id.clone()), + ), + DeclarationType::Array(DeclarationArrayType::new( + DeclarationType::FieldElement, + DeclarationConstant::Constant(foo_const_id.clone()), + )), + )), + ) + .into(), + TypedConstantSymbolDeclaration::new( + main_baz_const_id.clone(), + TypedConstantSymbol::Here(TypedConstant::new( + TypedExpression::Array( + ArrayExpressionInner::Identifier(main_bar_const_id.into()) + .annotate(Type::FieldElement, main_foo_const_id.clone()), + ), + DeclarationType::Array(DeclarationArrayType::new( + DeclarationType::FieldElement, + DeclarationConstant::Constant(foo_const_id.clone()), + )), + )), + ) + .into(), + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier(Identifier::from( + main_foo_const_id.clone(), + )) + .into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), + }), + ) + .into(), + ], + }; + + let expected_program: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![ + ("main".into(), expected_main_module), + ("foo".into(), foo_module), + ] + .into_iter() + .collect(), + }; + + assert_eq!(program, expected_program) + } +} diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 4113c74c6..0dbca3687 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -1139,14 +1139,12 @@ fn fold_program<'ast, T: Field>( let main_module = p.modules.remove(&p.main).unwrap(); let main_function = main_module - .functions - .into_iter() - .find(|(key, _)| key.id == "main") + .into_functions_iter() + .find(|d| d.key.id == "main") .unwrap() - .1; - + .symbol; let main_function = match main_function { - typed_absy::TypedFunctionSymbol::Here(f) => f, + typed_absy::TypedFunctionSymbol::Here(main) => main, _ => unreachable!(), }; diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index b859d0f91..5a07f1437 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -6,7 +6,7 @@ mod branch_isolator; mod constant_argument_checker; -mod constant_inliner; +mod constant_resolver; mod flat_propagation; mod flatten_complex_types; mod propagation; @@ -26,7 +26,7 @@ use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; use crate::compile::CompileConfig; use crate::ir::Prog; -use crate::static_analysis::constant_inliner::ConstantInliner; +use crate::static_analysis::constant_resolver::ConstantInliner; use crate::static_analysis::zir_propagation::ZirPropagator; use crate::typed_absy::{abi::Abi, TypedProgram}; use crate::zir::ZirProgram; diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 10f0eeaa4..cb09a669e 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -150,21 +150,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }) } - fn fold_module(&mut self, m: TypedModule<'ast, T>) -> Result, Error> { - Ok(TypedModule { - functions: m - .functions - .into_iter() - .map(|(key, fun)| { - if key.id == "main" { - self.fold_function_symbol(fun).map(|f| (key, f)) - } else { - Ok((key, fun)) - } - }) - .collect::>()?, - ..m - }) + fn fold_function_symbol_declaration( + &mut self, + s: TypedFunctionSymbolDeclaration<'ast, T>, + ) -> Result, Error> { + if s.key.id == "main" { + let key = s.key; + self.fold_function_symbol(s.symbol) + .map(|f| TypedFunctionSymbolDeclaration { key, symbol: f }) + } else { + Ok(s) + } } fn fold_function( diff --git a/zokrates_core/src/static_analysis/reducer/constants_reader.rs b/zokrates_core/src/static_analysis/reducer/constants_reader.rs new file mode 100644 index 000000000..9f39140a1 --- /dev/null +++ b/zokrates_core/src/static_analysis/reducer/constants_reader.rs @@ -0,0 +1,155 @@ +use crate::static_analysis::reducer::ConstantDefinitions; +use crate::typed_absy::{ + folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, + DeclarationConstant, Expr, FieldElementExpression, Identifier, StructExpression, + StructExpressionInner, StructType, UBitwidth, UExpression, UExpressionInner, +}; +use zokrates_field::Field; + +use std::convert::{TryFrom, TryInto}; + +pub struct ConstantsReader<'a, 'ast, T> { + constants: &'a ConstantDefinitions<'ast, T>, +} + +impl<'a, 'ast, T> ConstantsReader<'a, 'ast, T> { + pub fn with_constants(constants: &'a ConstantDefinitions<'ast, T>) -> Self { + Self { constants } + } +} + +impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + match e { + FieldElementExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + match self.constants.get(&c).cloned() { + Some(v) => v.try_into().unwrap(), + None => FieldElementExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }), + } + } + e => crate::typed_absy::folder::fold_field_expression(self, e), + } + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + match e { + BooleanExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + match self.constants.get(&c).cloned() { + Some(v) => v.try_into().unwrap(), + None => BooleanExpression::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }), + } + } + e => crate::typed_absy::folder::fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + ty: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + match e { + UExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + + match self.constants.get(&c).cloned() { + Some(v) => UExpression::try_from(v).unwrap().into_inner(), + None => UExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }), + } + } + e => crate::typed_absy::folder::fold_uint_expression_inner(self, ty, e), + } + } + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + match e { + ArrayExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + match self.constants.get(&c).cloned() { + Some(v) => ArrayExpression::try_from(v).unwrap().into_inner(), + None => ArrayExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }), + } + } + e => crate::typed_absy::folder::fold_array_expression_inner(self, ty, e), + } + } + + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + match e { + StructExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }) => { + assert_eq!(version, 0); + match self.constants.get(&c).cloned() { + Some(v) => StructExpression::try_from(v).unwrap().into_inner(), + None => StructExpressionInner::Identifier(Identifier { + id: CoreIdentifier::Constant(c), + version, + }), + } + } + e => crate::typed_absy::folder::fold_struct_expression_inner(self, ty, e), + } + } + + fn fold_declaration_constant( + &mut self, + c: DeclarationConstant<'ast, T>, + ) -> DeclarationConstant<'ast, T> { + match c { + DeclarationConstant::Constant(c) => { + let c = self.fold_canonical_constant_identifier(c); + + match self.constants.get(&c).cloned() { + Some(e) => match UExpression::try_from(e).unwrap().into_inner() { + UExpressionInner::Value(v) => DeclarationConstant::Concrete(v as u32), + _ => unreachable!(), + }, + None => DeclarationConstant::Constant(c), + } + } + c => crate::typed_absy::folder::fold_declaration_constant(self, c), + } + } +} diff --git a/zokrates_core/src/static_analysis/reducer/inline.rs b/zokrates_core/src/static_analysis/reducer/inline.rs index d0814b89e..d7132f81a 100644 --- a/zokrates_core/src/static_analysis/reducer/inline.rs +++ b/zokrates_core/src/static_analysis/reducer/inline.rs @@ -35,8 +35,8 @@ use crate::typed_absy::Identifier; use crate::typed_absy::TypedAssignee; use crate::typed_absy::{ ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr, - Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types, - UExpression, UExpressionInner, Variable, + Signature, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedProgram, + TypedStatement, Types, UExpression, UExpressionInner, Variable, }; use zokrates_field::Field; @@ -59,21 +59,18 @@ pub enum InlineError<'ast, T> { fn get_canonical_function<'ast, T: Field>( function_key: DeclarationFunctionKey<'ast, T>, program: &TypedProgram<'ast, T>, -) -> ( - DeclarationFunctionKey<'ast, T>, - TypedFunctionSymbol<'ast, T>, -) { - match program +) -> TypedFunctionSymbolDeclaration<'ast, T> { + let s = program .modules .get(&function_key.module) .unwrap() - .functions - .iter() - .find(|(key, _)| function_key == **key) - .unwrap() - { - (_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program), - (key, s) => (key.clone(), s.clone()), + .functions_iter() + .find(|d| d.key == function_key) + .unwrap(); + + match &s.symbol { + TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), &program), + _ => s.clone(), } } @@ -137,7 +134,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( } }; - let (decl_key, symbol) = get_canonical_function(k.clone(), program); + let decl = get_canonical_function(k.clone(), program); // get an assignment of generics for this call site let assignment: ConcreteGenericsAssignment<'ast> = k @@ -147,14 +144,14 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( InlineError::Generic( k.clone(), ConcreteFunctionKey { - module: decl_key.module.clone(), - id: decl_key.id, + module: decl.key.module.clone(), + id: decl.key.id, signature: inferred_signature.clone(), }, ) })?; - let f = match symbol { + let f = match decl.symbol { TypedFunctionSymbol::Here(f) => Ok(f), TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( e, @@ -172,7 +169,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)), }; - let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone()); + let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone()); let input_bindings: Vec> = ssa_f .arguments diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 1e4a5d4df..bbc71e0d5 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -11,6 +11,7 @@ // - unroll loops // - inline function calls. This includes applying shallow-ssa on the target function +mod constants_reader; mod inline; mod shallow_ssa; @@ -20,22 +21,20 @@ use crate::typed_absy::types::ConcreteGenericsAssignment; use crate::typed_absy::types::GGenericsAssignment; use crate::typed_absy::CanonicalConstantIdentifier; use crate::typed_absy::Folder; -use crate::typed_absy::UBitwidth; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression, - CoreIdentifier, DeclarationConstant, DeclarationSignature, Expr, FieldElementExpression, + ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, DeclarationSignature, Expr, FunctionCall, FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, - OwnedTypedModuleId, StructExpression, StructExpressionInner, StructType, TypedConstant, - TypedConstantSymbol, TypedExpression, TypedExpressionList, TypedExpressionListInner, - TypedFunction, TypedFunctionSymbol, TypedModule, TypedModuleId, TypedProgram, TypedStatement, - UExpression, UExpressionInner, Variable, + OwnedTypedModuleId, TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, + TypedExpression, TypedExpressionList, TypedExpressionListInner, TypedFunction, + TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedModuleId, TypedProgram, + TypedStatement, TypedSymbolDeclaration, UExpression, UExpressionInner, Variable, }; -use std::convert::{TryFrom, TryInto}; use zokrates_field::Field; +use self::constants_reader::ConstantsReader; use self::shallow_ssa::ShallowTransformer; use crate::static_analysis::propagation::{Constants, Propagator}; @@ -49,16 +48,16 @@ type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; // A folder to inline all constant definitions down to a single litteral. Also register them in the state for later use. -struct ConstantCallsInliner<'ast, T> { +struct ConstantsBuilder<'ast, T> { treated: HashSet, constants: ConstantDefinitions<'ast, T>, location: OwnedTypedModuleId, program: TypedProgram<'ast, T>, } -impl<'ast, T> ConstantCallsInliner<'ast, T> { +impl<'ast, T: Field> ConstantsBuilder<'ast, T> { fn with_program(program: TypedProgram<'ast, T>) -> Self { - ConstantCallsInliner { + ConstantsBuilder { constants: ConstantDefinitions::default(), location: program.main.clone(), treated: HashSet::default(), @@ -76,9 +75,27 @@ impl<'ast, T> ConstantCallsInliner<'ast, T> { fn treated(&self, id: &TypedModuleId) -> bool { self.treated.contains(id) } + + fn update_program(&mut self) { + let mut p = TypedProgram { + main: "".into(), + modules: BTreeMap::default(), + }; + std::mem::swap(&mut self.program, &mut p); + let mut reader = ConstantsReader::with_constants(&self.constants); + self.program = reader.fold_program(p); + } + + fn update_symbol_declaration( + &self, + s: TypedSymbolDeclaration<'ast, T>, + ) -> TypedSymbolDeclaration<'ast, T> { + let mut reader = ConstantsReader::with_constants(&self.constants); + reader.fold_symbol_declaration(s) + } } -impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsBuilder<'ast, T> { type Error = Error; fn fold_module_id( @@ -88,6 +105,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet if !self.treated(&id) { let current_m_id = self.change_location(id.clone()); + // I did not find a way to achieve this without cloning the module. Assuming we do not clone: + // to fold the module, we need to consume it, so it gets removed from the modules + // but to inline the calls while folding the module, all modules must be present + // therefore we clone... + // this does not lead to a module being folded more than once, as the first time + // we change location to this module, it's added to the `treated` set let m = self.program.modules.get(&id).cloned().unwrap(); let m = self.fold_module(m)?; self.program.modules.insert(id.clone(), m); @@ -96,201 +119,81 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantCallsInliner<'ast, T> { Ok(id) } - fn fold_field_expression( + fn fold_symbol_declaration( &mut self, - e: FieldElementExpression<'ast, T>, - ) -> Result, Self::Error> { - match e { - FieldElementExpression::Identifier(Identifier { - id: CoreIdentifier::Constant(c), - version, - }) => { - assert_eq!(version, 0); - Ok(self.constants.get(&c).cloned().unwrap().try_into().unwrap()) - } - e => fold_field_expression(self, e), - } - } + s: TypedSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + // before we treat a symbol, propagate the constants into it + let s = self.update_symbol_declaration(s); - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> Result, Self::Error> { - match e { - BooleanExpression::Identifier(Identifier { - id: CoreIdentifier::Constant(c), - version, - }) => { - assert_eq!(version, 0); - Ok(self.constants.get(&c).cloned().unwrap().try_into().unwrap()) - } - e => fold_boolean_expression(self, e), - } + fold_symbol_declaration(self, s) } - fn fold_uint_expression_inner( + fn fold_constant_symbol_declaration( &mut self, - ty: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - UExpressionInner::Identifier(Identifier { - id: CoreIdentifier::Constant(c), - version, - }) => { - assert_eq!(version, 0); - Ok( - UExpression::try_from(self.constants.get(&c).cloned().unwrap()) - .unwrap() - .into_inner(), - ) - } - e => fold_uint_expression_inner(self, ty, e), - } - } - - fn fold_array_expression_inner( - &mut self, - ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - ArrayExpressionInner::Identifier(Identifier { - id: CoreIdentifier::Constant(c), - version, - }) => { - assert_eq!(version, 0); - Ok( - ArrayExpression::try_from(self.constants.get(&c).cloned().unwrap()) - .unwrap() - .into_inner(), - ) - } - e => fold_array_expression_inner(self, ty, e), - } - } + d: TypedConstantSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + let id = self.fold_canonical_constant_identifier(d.id)?; + + match d.symbol { + TypedConstantSymbol::Here(c) => { + let c = self.fold_constant(c)?; + + // wrap this expression in a function + let wrapper = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![c.expression])], + signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), + }; - fn fold_struct_expression_inner( - &mut self, - ty: &StructType<'ast, T>, - e: StructExpressionInner<'ast, T>, - ) -> Result, Self::Error> { - match e { - StructExpressionInner::Identifier(Identifier { - id: CoreIdentifier::Constant(c), - version, - }) => { - assert_eq!(version, 0); - Ok( - StructExpression::try_from(self.constants.get(&c).cloned().unwrap()) - .unwrap() - .into_inner(), - ) - } - e => fold_struct_expression_inner(self, ty, e), - } - } + let mut inlined_wrapper = reduce_function( + wrapper, + ConcreteGenericsAssignment::default(), + &self.program, + )?; - fn fold_declaration_constant( - &mut self, - c: DeclarationConstant<'ast, T>, - ) -> Result, Self::Error> { - match c { - DeclarationConstant::Constant(c) => { - let c = self.fold_canonical_constant_identifier(c)?; - - if let UExpressionInner::Value(v) = - UExpression::try_from(self.constants.get(&c).cloned().unwrap()) - .unwrap() - .into_inner() + if let TypedStatement::Return(mut expressions) = + inlined_wrapper.statements.pop().unwrap() { - Ok(DeclarationConstant::Concrete(v as u32)) + assert_eq!(expressions.len(), 1); + let constant_expression = expressions.pop().unwrap(); + + use crate::typed_absy::Constant; + if !constant_expression.is_constant() { + return Err(Error::ConstantReduction(id.id.to_string(), id.module)); + }; + + use crate::typed_absy::Typed; + if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>( + c.ty.clone(), + ) + .unwrap() + == constant_expression.get_type() + { + // add to the constant map + self.constants + .insert(id.clone(), constant_expression.clone()); + + // after we reduced a constant, propagate it through the whole program + self.update_program(); + + Ok(TypedConstantSymbolDeclaration { + id, + symbol: TypedConstantSymbol::Here(TypedConstant { + expression: constant_expression, + ty: c.ty, + }), + }) + } else { + Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, c.ty))) + } } else { - unreachable!() + Err(Error::ConstantReduction(id.id.to_string(), id.module)) } } - c => fold_declaration_constant(self, c), + _ => unreachable!("all constants should be local"), } } - - fn fold_module( - &mut self, - m: TypedModule<'ast, T>, - ) -> Result, Self::Error> { - Ok(TypedModule { - constants: m - .constants - .into_iter() - .map(|(id, tc)| match tc { - TypedConstantSymbol::Here(c) => { - let c = self.fold_constant(c)?; - - let ty = c.ty; - - // replace the existing constants in this expression - let constant_replaced_expression = self.fold_expression(c.expression)?; - - // wrap this expression in a function - let wrapper = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - constant_replaced_expression, - ])], - signature: DeclarationSignature::new().outputs(vec![ty.clone()]), - }; - - let mut inlined_wrapper = reduce_function( - wrapper, - ConcreteGenericsAssignment::default(), - &self.program, - )?; - - if let TypedStatement::Return(mut expressions) = - inlined_wrapper.statements.pop().unwrap() - { - assert_eq!(expressions.len(), 1); - let constant_expression = expressions.pop().unwrap(); - - use crate::typed_absy::Constant; - if !constant_expression.is_constant() { - return Err(Error::ConstantReduction( - id.id.to_string(), - id.module, - )); - }; - - use crate::typed_absy::Typed; - if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>(ty.clone()).unwrap() == constant_expression.get_type() { - // add to the constant map - self.constants.insert(id.clone(), constant_expression.clone()); - Ok(( - id, - TypedConstantSymbol::Here(TypedConstant { - expression: constant_expression, - ty, - }), - )) - } else { - Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, ty))) - } - } else { - Err(Error::ConstantReduction(id.id.to_string(), id.module)) - } - } - _ => unreachable!("all constants should be local"), - }) - .collect::, _>>()?, - functions: m - .functions - .into_iter() - .map(|(key, fun)| { - let key = self.fold_declaration_function_key(key)?; - let fun = self.fold_function_symbol(fun)?; - Ok((key, fun)) - }) - .collect::>()?, - }) - } } // An SSA version map, giving access to the latest version number for each identifier @@ -756,20 +659,19 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { pub fn reduce_program(p: TypedProgram) -> Result, Error> { // inline all constants and replace them in the program - let mut constant_calls_inliner = ConstantCallsInliner::with_program(p.clone()); + let mut constant_calls_inliner = ConstantsBuilder::with_program(p.clone()); let p = constant_calls_inliner.fold_program(p)?; // inline starting from main let main_module = p.modules.get(&p.main).unwrap().clone(); - let (main_key, main_function) = main_module - .functions - .iter() - .find(|(k, _)| k.id == "main") + let decl = main_module + .functions_iter() + .find(|d| d.key.id == "main") .unwrap(); - let main_function = match main_function { + let main_function = match &decl.symbol { TypedFunctionSymbol::Here(f) => f.clone(), _ => unreachable!(), }; @@ -783,13 +685,12 @@ pub fn reduce_program(p: TypedProgram) -> Result, E modules: vec![( p.main.clone(), TypedModule { - functions: vec![( - main_key.clone(), - TypedFunctionSymbol::Here(main_function), - )] - .into_iter() - .collect(), - constants: Default::default(), + symbols: vec![TypedSymbolDeclaration::Function( + TypedFunctionSymbolDeclaration { + key: decl.key.clone(), + symbol: TypedFunctionSymbol::Here(main_function), + }, + )], }, )] .into_iter() @@ -988,27 +889,26 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![ - ( + symbols: vec![ + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "foo").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(foo), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(main), - ), - ] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into(), + ], }, )] .into_iter() @@ -1064,17 +964,15 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![( + symbols: vec![TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into()], }, )] .into_iter() @@ -1191,24 +1089,23 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![ - ( + symbols: vec![ + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(main), - ), - ] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into(), + ], }, )] .into_iter() @@ -1283,17 +1180,15 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![( + symbols: vec![TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into()], }, )] .into_iter() @@ -1419,24 +1314,23 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![ - ( + symbols: vec![ + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(main), - ), - ] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into(), + ], }, )] .into_iter() @@ -1511,17 +1405,15 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![( + symbols: vec![TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new() .inputs(vec![DeclarationType::FieldElement]) .outputs(vec![DeclarationType::FieldElement]), ), TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into()], }, )] .into_iter() @@ -1677,25 +1569,25 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![ - ( + symbols: vec![ + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "bar") .signature(bar_signature.clone()), TypedFunctionSymbol::Here(bar), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main"), TypedFunctionSymbol::Here(main), - ), - ] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into(), + ], }, )] .into_iter() @@ -1737,14 +1629,12 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![( + symbols: vec![TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main") .signature(DeclarationSignature::new()), TypedFunctionSymbol::Here(expected_main), - )] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into()], }, )] .into_iter() @@ -1818,22 +1708,21 @@ mod tests { modules: vec![( "main".into(), TypedModule { - functions: vec![ - ( + symbols: vec![ + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "foo") .signature(foo_signature.clone()), TypedFunctionSymbol::Here(foo), - ), - ( + ) + .into(), + TypedFunctionSymbolDeclaration::new( DeclarationFunctionKey::with_location("main", "main").signature( DeclarationSignature::new().inputs(vec![]).outputs(vec![]), ), TypedFunctionSymbol::Here(main), - ), - ] - .into_iter() - .collect(), - constants: Default::default(), + ) + .into(), + ], }, )] .into_iter() diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 6256e10ab..65b53bba6 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -35,15 +35,15 @@ mod tests { }; use crate::typed_absy::{ parameter::DeclarationParameter, variable::DeclarationVariable, ConcreteType, - TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, + TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, + TypedProgram, }; use std::collections::BTreeMap; use zokrates_field::Bn128Field; #[test] fn generate_abi_from_typed_ast() { - let mut functions = BTreeMap::new(); - functions.insert( + let symbols = vec![TypedFunctionSymbolDeclaration::new( ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![ @@ -62,16 +62,11 @@ mod tests { .outputs(vec![ConcreteType::FieldElement]) .into(), }), - ); + ) + .into()]; let mut modules = BTreeMap::new(); - modules.insert( - "main".into(), - TypedModule { - functions, - constants: Default::default(), - }, - ); + modules.insert("main".into(), TypedModule { symbols }); let typed_ast: TypedProgram = TypedProgram { main: "main".into(), diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index fae489b67..b9a3d6bff 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -47,6 +47,27 @@ pub trait Folder<'ast, T: Field>: Sized { fold_module(self, m) } + fn fold_symbol_declaration( + &mut self, + s: TypedSymbolDeclaration<'ast, T>, + ) -> TypedSymbolDeclaration<'ast, T> { + fold_symbol_declaration(self, s) + } + + fn fold_function_symbol_declaration( + &mut self, + s: TypedFunctionSymbolDeclaration<'ast, T>, + ) -> TypedFunctionSymbolDeclaration<'ast, T> { + fold_function_symbol_declaration(self, s) + } + + fn fold_constant_symbol_declaration( + &mut self, + s: TypedConstantSymbolDeclaration<'ast, T>, + ) -> TypedConstantSymbolDeclaration<'ast, T> { + fold_constant_symbol_declaration(self, s) + } + fn fold_constant(&mut self, c: TypedConstant<'ast, T>) -> TypedConstant<'ast, T> { fold_constant(self, c) } @@ -383,29 +404,48 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( m: TypedModule<'ast, T>, ) -> TypedModule<'ast, T> { TypedModule { - constants: m - .constants - .into_iter() - .map(|(id, tc)| { - ( - f.fold_canonical_constant_identifier(id), - f.fold_constant_symbol(tc), - ) - }) - .collect(), - functions: m - .functions + symbols: m + .symbols .into_iter() - .map(|(key, fun)| { - ( - f.fold_declaration_function_key(key), - f.fold_function_symbol(fun), - ) - }) + .map(|s| f.fold_symbol_declaration(s)) .collect(), } } +pub fn fold_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + d: TypedSymbolDeclaration<'ast, T>, +) -> TypedSymbolDeclaration<'ast, T> { + match d { + TypedSymbolDeclaration::Function(d) => { + TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d)) + } + TypedSymbolDeclaration::Constant(d) => { + TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d)) + } + } +} + +pub fn fold_function_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + d: TypedFunctionSymbolDeclaration<'ast, T>, +) -> TypedFunctionSymbolDeclaration<'ast, T> { + TypedFunctionSymbolDeclaration { + key: f.fold_declaration_function_key(d.key), + symbol: f.fold_function_symbol(d.symbol), + } +} + +pub fn fold_constant_symbol_declaration<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + d: TypedConstantSymbolDeclaration<'ast, T>, +) -> TypedConstantSymbolDeclaration<'ast, T> { + TypedConstantSymbolDeclaration { + id: f.fold_canonical_constant_identifier(d.id), + symbol: f.fold_constant_symbol(d.symbol), + } +} + pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: TypedStatement<'ast, T>, @@ -977,7 +1017,7 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>( } } -fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>( +pub fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, c: DeclarationConstant<'ast, T>, ) -> DeclarationConstant<'ast, T> { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index e2f58456a..f9d40185e 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -91,12 +91,11 @@ impl<'ast, T> TypedProgram<'ast, T> { impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn abi(&self) -> Abi { - let main = self.modules[&self.main] - .functions - .iter() - .find(|(id, _)| id.id == "main") + let main = &self.modules[&self.main] + .functions_iter() + .find(|s| s.key.id == "main") .unwrap() - .1; + .symbol; let main = match main { TypedFunctionSymbol::Here(main) => main, _ => unreachable!(), @@ -163,13 +162,84 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> { } } +#[derive(PartialEq, Debug, Clone)] +pub struct TypedFunctionSymbolDeclaration<'ast, T> { + pub key: DeclarationFunctionKey<'ast, T>, + pub symbol: TypedFunctionSymbol<'ast, T>, +} + +impl<'ast, T> TypedFunctionSymbolDeclaration<'ast, T> { + pub fn new(key: DeclarationFunctionKey<'ast, T>, symbol: TypedFunctionSymbol<'ast, T>) -> Self { + TypedFunctionSymbolDeclaration { key, symbol } + } +} + +#[derive(PartialEq, Debug, Clone)] +pub struct TypedConstantSymbolDeclaration<'ast, T> { + pub id: CanonicalConstantIdentifier<'ast>, + pub symbol: TypedConstantSymbol<'ast, T>, +} + +impl<'ast, T> TypedConstantSymbolDeclaration<'ast, T> { + pub fn new( + id: CanonicalConstantIdentifier<'ast>, + symbol: TypedConstantSymbol<'ast, T>, + ) -> Self { + TypedConstantSymbolDeclaration { id, symbol } + } +} + +#[derive(PartialEq, Debug, Clone)] +pub enum TypedSymbolDeclaration<'ast, T> { + Function(TypedFunctionSymbolDeclaration<'ast, T>), + Constant(TypedConstantSymbolDeclaration<'ast, T>), +} + +impl<'ast, T> From> for TypedSymbolDeclaration<'ast, T> { + fn from(d: TypedFunctionSymbolDeclaration<'ast, T>) -> Self { + Self::Function(d) + } +} + +impl<'ast, T> From> for TypedSymbolDeclaration<'ast, T> { + fn from(d: TypedConstantSymbolDeclaration<'ast, T>) -> Self { + Self::Constant(d) + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedSymbolDeclaration<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TypedSymbolDeclaration::Function(fun) => write!(f, "{}", fun), + TypedSymbolDeclaration::Constant(c) => write!(f, "{}", c), + } + } +} + +pub type TypedSymbolDeclarations<'ast, T> = Vec>; + /// A typed module as a collection of functions. Types have been resolved during semantic checking. #[derive(PartialEq, Debug, Clone)] pub struct TypedModule<'ast, T> { - /// Functions of the module - pub functions: TypedFunctionSymbols<'ast, T>, - /// Constants defined in module - pub constants: TypedConstantSymbols<'ast, T>, + pub symbols: TypedSymbolDeclarations<'ast, T>, +} + +impl<'ast, T> TypedModule<'ast, T> { + pub fn functions_iter(&self) -> impl Iterator> { + self.symbols.iter().filter_map(|s| match s { + TypedSymbolDeclaration::Function(d) => Some(d), + _ => None, + }) + } + + pub fn into_functions_iter( + self, + ) -> impl Iterator> { + self.symbols.into_iter().filter_map(|s| match s { + TypedSymbolDeclaration::Function(d) => Some(d), + _ => None, + }) + } } #[derive(Clone, PartialEq, Debug)] @@ -189,50 +259,65 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { TypedFunctionSymbol::There(key) => modules .get(&key.module) .unwrap() - .functions - .get(key) + .functions_iter() + .find(|d| d.key == *key) .unwrap() + .symbol .signature(&modules), TypedFunctionSymbol::Flat(flat_fun) => flat_fun.typed_signature(), } } } +impl<'ast, T: fmt::Display> fmt::Display for TypedConstantSymbolDeclaration<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.symbol { + TypedConstantSymbol::Here(ref tc) => { + write!(f, "const {} {} = {}", tc.ty, self.id, tc.expression) + } + TypedConstantSymbol::There(ref imported_id) => { + write!( + f, + "from \"{}\" import {} as {}", + imported_id.module.display(), + imported_id.id, + self.id + ) + } + } + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedFunctionSymbolDeclaration<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.symbol { + TypedFunctionSymbol::Here(ref function) => write!(f, "def {}{}", self.key.id, function), + TypedFunctionSymbol::There(ref fun_key) => write!( + f, + "from \"{}\" import {} as {} // with signature {}", + fun_key.module.display(), + fun_key.id, + self.key.id, + self.key.signature + ), + TypedFunctionSymbol::Flat(ref flat_fun) => { + write!( + f, + "def {}{}:\n\t// hidden", + self.key.id, + flat_fun.typed_signature::() + ) + } + } + } +} + impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = self - .constants + .symbols .iter() - .map(|(id, symbol)| match symbol { - TypedConstantSymbol::Here(ref tc) => { - format!("const {} {} = {}", tc.ty, id.id, tc.expression) - } - TypedConstantSymbol::There(ref imported_id) => { - format!( - "from \"{}\" import {} as {}", - imported_id.module.display(), - imported_id.id, - id.id - ) - } - }) - .chain(self.functions.iter().map(|(key, symbol)| match symbol { - TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), - TypedFunctionSymbol::There(ref fun_key) => format!( - "from \"{}\" import {} as {} // with signature {}", - fun_key.module.display(), - fun_key.id, - key.id, - key.signature - ), - TypedFunctionSymbol::Flat(ref flat_fun) => { - format!( - "def {}{}:\n\t// hidden", - key.id, - flat_fun.typed_signature::() - ) - } - })) + .map(|s| format!("{}", s)) .collect::>(); write!(f, "{}", res.join("\n")) diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index fb38b2f70..f952e28bc 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -55,6 +55,27 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_module(self, m) } + fn fold_symbol_declaration( + &mut self, + s: TypedSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + fold_symbol_declaration(self, s) + } + + fn fold_function_symbol_declaration( + &mut self, + s: TypedFunctionSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + fold_function_symbol_declaration(self, s) + } + + fn fold_constant_symbol_declaration( + &mut self, + s: TypedConstantSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + fold_constant_symbol_declaration(self, s) + } + fn fold_constant( &mut self, c: TypedConstant<'ast, T>, @@ -1147,24 +1168,49 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( } } +pub fn fold_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + d: TypedSymbolDeclaration<'ast, T>, +) -> Result, F::Error> { + Ok(match d { + TypedSymbolDeclaration::Function(d) => { + TypedSymbolDeclaration::Function(f.fold_function_symbol_declaration(d)?) + } + TypedSymbolDeclaration::Constant(d) => { + TypedSymbolDeclaration::Constant(f.fold_constant_symbol_declaration(d)?) + } + }) +} + +pub fn fold_function_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + d: TypedFunctionSymbolDeclaration<'ast, T>, +) -> Result, F::Error> { + Ok(TypedFunctionSymbolDeclaration { + key: f.fold_declaration_function_key(d.key)?, + symbol: f.fold_function_symbol(d.symbol)?, + }) +} + +pub fn fold_constant_symbol_declaration<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + d: TypedConstantSymbolDeclaration<'ast, T>, +) -> Result, F::Error> { + Ok(TypedConstantSymbolDeclaration { + id: f.fold_canonical_constant_identifier(d.id)?, + symbol: f.fold_constant_symbol(d.symbol)?, + }) +} + pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, m: TypedModule<'ast, T>, ) -> Result, F::Error> { Ok(TypedModule { - constants: m - .constants - .into_iter() - .map(|(key, tc)| f.fold_constant_symbol(tc).map(|tc| (key, tc))) - .collect::>()?, - functions: m - .functions + symbols: m + .symbols .into_iter() - .map(|(key, fun)| { - let key = f.fold_declaration_function_key(key)?; - let fun = f.fold_function_symbol(fun)?; - Ok((key, fun)) - }) + .map(|s| f.fold_symbol_declaration(s)) .collect::>()?, }) } From dc84b29ee877f8ad29ca7254c9b95b083cc8e93f Mon Sep 17 00:00:00 2001 From: dark64 Date: Sun, 12 Sep 2021 23:59:36 +0200 Subject: [PATCH 42/78] remove cast to usize causing wrong values in wasm environment --- zokrates_core/src/flatten/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index ca916636b..79d3948f9 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1465,7 +1465,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let res = match expr.into_inner() { UExpressionInner::Value(x) => { - FlatUExpression::with_field(FlatExpression::Number(T::from(x as usize))) + FlatUExpression::with_field(FlatExpression::Number(T::from(x))) } // force to be a field element UExpressionInner::Identifier(x) => { let field = FlatExpression::Identifier(*self.layout.get(&x).unwrap()); From 5843cb2377c4abff6b891b2b2c6c770ae9b494fc Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 13 Sep 2021 00:18:05 +0200 Subject: [PATCH 43/78] add changelog --- changelogs/unreleased/998-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/998-dark64 diff --git a/changelogs/unreleased/998-dark64 b/changelogs/unreleased/998-dark64 new file mode 100644 index 000000000..d784d85c1 --- /dev/null +++ b/changelogs/unreleased/998-dark64 @@ -0,0 +1 @@ +Fix invalid cast to `usize` which caused wrong values in 32-bit environments \ No newline at end of file From c1155888c20c04033958d783af8e3c2f4e793bfb Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 13 Sep 2021 15:54:12 +0200 Subject: [PATCH 44/78] update before and after folding --- zokrates_core/src/static_analysis/reducer/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index bbc71e0d5..1a3c8549b 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -126,7 +126,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsBuilder<'ast, T> { // before we treat a symbol, propagate the constants into it let s = self.update_symbol_declaration(s); - fold_symbol_declaration(self, s) + let s = fold_symbol_declaration(self, s)?; + + Ok(self.update_symbol_declaration(s)) } fn fold_constant_symbol_declaration( From dbf0574d3b64ea0b7b9a2f4af973a53f310793b7 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 13 Sep 2021 17:29:35 +0200 Subject: [PATCH 45/78] add comments --- zokrates_core/src/static_analysis/reducer/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 1a3c8549b..3d6967438 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -123,11 +123,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsBuilder<'ast, T> { &mut self, s: TypedSymbolDeclaration<'ast, T>, ) -> Result, Self::Error> { - // before we treat a symbol, propagate the constants into it + // before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module. let s = self.update_symbol_declaration(s); let s = fold_symbol_declaration(self, s)?; + // after we treat the symbol, propagate again, as treating this symbol may have triggered checking another module, resolving new constants which this symbol may be using. Ok(self.update_symbol_declaration(s)) } From ec01008165d2bb8d80291a39800be756f433cd5b Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 13 Sep 2021 18:49:13 +0200 Subject: [PATCH 46/78] fix invalid bitwidth set on select index --- zokrates_core/src/static_analysis/zir_propagation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/zir_propagation.rs b/zokrates_core/src/static_analysis/zir_propagation.rs index 9b09a59ee..c3028a030 100644 --- a/zokrates_core/src/static_analysis/zir_propagation.rs +++ b/zokrates_core/src/static_analysis/zir_propagation.rs @@ -462,7 +462,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { .cloned() .ok_or_else(|| Error::OutOfBounds(v, e.len() as u128)) .map(|e| e.into_inner()), - i => Ok(UExpressionInner::Select(e, box i.annotate(bitwidth))), + i => Ok(UExpressionInner::Select(e, box i.annotate(UBitwidth::B32))), } } UExpressionInner::Add(box e1, box e2) => { From 6d382fdd2c1e8175372ccc0bf87bd3eb94656cc6 Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Wed, 15 Sep 2021 10:53:33 +0200 Subject: [PATCH 47/78] Reduce the size of the verifier for scheme G16. --- .../src/proof_system/scheme/groth16.rs | 10 +- zokrates_core/src/proof_system/solidity.rs | 144 ++++++++++++++++++ 2 files changed, 149 insertions(+), 5 deletions(-) diff --git a/zokrates_core/src/proof_system/scheme/groth16.rs b/zokrates_core/src/proof_system/scheme/groth16.rs index b84b762bb..b473e5475 100644 --- a/zokrates_core/src/proof_system/scheme/groth16.rs +++ b/zokrates_core/src/proof_system/scheme/groth16.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::{SOLIDITY_G2_ADDITION_LIB, SOLIDITY_PAIRING_LIB}; +use crate::proof_system::solidity::{SOLIDITY_PAIRING_LIB_SANS_BN256G2}; use crate::proof_system::{G1Affine, G2Affine, SolidityCompatibleField, SolidityCompatibleScheme}; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -32,9 +32,9 @@ impl NonUniversalScheme for G16 {} impl SolidityCompatibleScheme for G16 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { - let (mut template_text, solidity_pairing_lib) = ( + let (mut template_text, solidity_pairing_lib_sans_bn256g2) = ( String::from(CONTRACT_TEMPLATE), - String::from(SOLIDITY_PAIRING_LIB), + String::from(SOLIDITY_PAIRING_LIB_SANS_BN256G2), ); let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); @@ -123,8 +123,8 @@ impl SolidityCompatibleScheme for G16 { template_text = re.replace_all(&template_text, "uint256($v)").to_string(); format!( - "{}{}{}", - SOLIDITY_G2_ADDITION_LIB, solidity_pairing_lib, template_text + "{}{}", + solidity_pairing_lib_sans_bn256g2, template_text ) } } diff --git a/zokrates_core/src/proof_system/solidity.rs b/zokrates_core/src/proof_system/solidity.rs index 9c1309e8d..d9380235c 100644 --- a/zokrates_core/src/proof_system/solidity.rs +++ b/zokrates_core/src/proof_system/solidity.rs @@ -553,3 +553,147 @@ library Pairing { } } "#; + +pub const SOLIDITY_PAIRING_LIB_SANS_BN256G2: &str = r#"// This file is MIT Licensed. +// +// Copyright 2017 Christian Reitwiessner +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +pragma solidity ^0.8.0; +library Pairing { + struct G1Point { + uint X; + uint Y; + } + // Encoding of field elements is: X[0] * z + X[1] + struct G2Point { + uint[2] X; + uint[2] Y; + } + /// @return the generator of G1 + function P1() pure internal returns (G1Point memory) { + return G1Point(1, 2); + } + /// @return the generator of G2 + function P2() pure internal returns (G2Point memory) { + return G2Point( + [10857046999023057135944570762232829481370756359578518086990519993285655852781, + 11559732032986387107991004021392285783925812861821192530917403151452391805634], + [8495653923123431417604973247489272438418190587263600148770280649306958101930, + 4082367875863433681332203403145435568316851327593401208105741076214120093531] + ); + } + /// @return the negation of p, i.e. p.addition(p.negate()) should be zero. + function negate(G1Point memory p) pure internal returns (G1Point memory) { + // The prime q in the base field F_q for G1 + uint q = 21888242871839275222246405745257275088696311157297823662689037894645226208583; + if (p.X == 0 && p.Y == 0) + return G1Point(0, 0); + return G1Point(p.X, q - (p.Y % q)); + } + /// @return r the sum of two points of G1 + function addition(G1Point memory p1, G1Point memory p2) internal view returns (G1Point memory r) { + uint[4] memory input; + input[0] = p1.X; + input[1] = p1.Y; + input[2] = p2.X; + input[3] = p2.Y; + bool success; + assembly { + success := staticcall(sub(gas(), 2000), 6, input, 0xc0, r, 0x60) + // Use "invalid" to make gas estimation work + switch success case 0 { invalid() } + } + require(success); + } + /// @return r the product of a point on G1 and a scalar, i.e. + /// p == p.scalar_mul(1) and p.addition(p) == p.scalar_mul(2) for all points p. + function scalar_mul(G1Point memory p, uint s) internal view returns (G1Point memory r) { + uint[3] memory input; + input[0] = p.X; + input[1] = p.Y; + input[2] = s; + bool success; + assembly { + success := staticcall(sub(gas(), 2000), 7, input, 0x80, r, 0x60) + // Use "invalid" to make gas estimation work + switch success case 0 { invalid() } + } + require (success); + } + /// @return the result of computing the pairing check + /// e(p1[0], p2[0]) * .... * e(p1[n], p2[n]) == 1 + /// For example pairing([P1(), P1().negate()], [P2(), P2()]) should + /// return true. + function pairing(G1Point[] memory p1, G2Point[] memory p2) internal view returns (bool) { + require(p1.length == p2.length); + uint elements = p1.length; + uint inputSize = elements * 6; + uint[] memory input = new uint[](inputSize); + for (uint i = 0; i < elements; i++) + { + input[i * 6 + 0] = p1[i].X; + input[i * 6 + 1] = p1[i].Y; + input[i * 6 + 2] = p2[i].X[1]; + input[i * 6 + 3] = p2[i].X[0]; + input[i * 6 + 4] = p2[i].Y[1]; + input[i * 6 + 5] = p2[i].Y[0]; + } + uint[1] memory out; + bool success; + assembly { + success := staticcall(sub(gas(), 2000), 8, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) + // Use "invalid" to make gas estimation work + switch success case 0 { invalid() } + } + require(success); + return out[0] != 0; + } + /// Convenience method for a pairing check for two pairs. + function pairingProd2(G1Point memory a1, G2Point memory a2, G1Point memory b1, G2Point memory b2) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](2); + G2Point[] memory p2 = new G2Point[](2); + p1[0] = a1; + p1[1] = b1; + p2[0] = a2; + p2[1] = b2; + return pairing(p1, p2); + } + /// Convenience method for a pairing check for three pairs. + function pairingProd3( + G1Point memory a1, G2Point memory a2, + G1Point memory b1, G2Point memory b2, + G1Point memory c1, G2Point memory c2 + ) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](3); + G2Point[] memory p2 = new G2Point[](3); + p1[0] = a1; + p1[1] = b1; + p1[2] = c1; + p2[0] = a2; + p2[1] = b2; + p2[2] = c2; + return pairing(p1, p2); + } + /// Convenience method for a pairing check for four pairs. + function pairingProd4( + G1Point memory a1, G2Point memory a2, + G1Point memory b1, G2Point memory b2, + G1Point memory c1, G2Point memory c2, + G1Point memory d1, G2Point memory d2 + ) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](4); + G2Point[] memory p2 = new G2Point[](4); + p1[0] = a1; + p1[1] = b1; + p1[2] = c1; + p1[3] = d1; + p2[0] = a2; + p2[1] = b2; + p2[2] = c2; + p2[3] = d2; + return pairing(p1, p2); + } +} +"#; From b066a663314150828e973390c60e333698c617ae Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Thu, 16 Sep 2021 16:38:42 +0200 Subject: [PATCH 48/78] Format code. --- zokrates_core/src/proof_system/scheme/groth16.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/zokrates_core/src/proof_system/scheme/groth16.rs b/zokrates_core/src/proof_system/scheme/groth16.rs index b473e5475..a998c953c 100644 --- a/zokrates_core/src/proof_system/scheme/groth16.rs +++ b/zokrates_core/src/proof_system/scheme/groth16.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::{SOLIDITY_PAIRING_LIB_SANS_BN256G2}; +use crate::proof_system::solidity::SOLIDITY_PAIRING_LIB_SANS_BN256G2; use crate::proof_system::{G1Affine, G2Affine, SolidityCompatibleField, SolidityCompatibleScheme}; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -122,10 +122,7 @@ impl SolidityCompatibleScheme for G16 { let re = Regex::new(r"(?P0[xX][0-9a-fA-F]{64})").unwrap(); template_text = re.replace_all(&template_text, "uint256($v)").to_string(); - format!( - "{}{}", - solidity_pairing_lib_sans_bn256g2, template_text - ) + format!("{}{}", solidity_pairing_lib_sans_bn256g2, template_text) } } From 5d91fa582835ee801d102c9ac5806374f6d641ae Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 17 Sep 2021 19:23:16 +0300 Subject: [PATCH 49/78] clean --- zokrates_core/src/semantics.rs | 68 ++++--- .../reducer/constants_reader.rs | 30 ++- .../reducer/constants_writer.rs | 163 ++++++++++++++++ .../src/static_analysis/reducer/mod.rs | 183 ++---------------- 4 files changed, 238 insertions(+), 206 deletions(-) create mode 100644 zokrates_core/src/static_analysis/reducer/constants_writer.rs diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 44eeb7100..d4b224c33 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -268,7 +268,13 @@ pub struct ScopedIdentifier<'ast> { impl<'ast> ScopedIdentifier<'ast> { fn is_constant(&self) -> bool { - self.level == 0 + let res = self.level == 0; + + // currently this is encoded twice: in the level and in the identifier itself. + // we add a sanity check to make sure the two agree + assert_eq!(res, matches!(self.id, CoreIdentifier::Constant(..))); + + res } } @@ -577,15 +583,16 @@ impl<'ast, T: Field> Checker<'ast, T> { .in_file(module_id), ), true => { - symbols.push(TypedSymbolDeclaration::Constant( - TypedConstantSymbolDeclaration { - id: CanonicalConstantIdentifier::new( + symbols.push( + TypedConstantSymbolDeclaration::new( + CanonicalConstantIdentifier::new( declaration.id, module_id.into(), ), - symbol: TypedConstantSymbol::Here(c.clone()), - }, - )); + TypedConstantSymbol::Here(c.clone()), + ) + .into(), + ); self.insert_into_scope(Variable::with_id_and_type( CoreIdentifier::Constant(CanonicalConstantIdentifier::new( declaration.id, @@ -633,16 +640,17 @@ impl<'ast, T: Field> Checker<'ast, T> { ) .signature(funct.signature.clone()), ); - symbols.push(TypedSymbolDeclaration::Function( - TypedFunctionSymbolDeclaration { - key: DeclarationFunctionKey::with_location( + symbols.push( + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location( module_id.to_path_buf(), declaration.id, ) .signature(funct.signature.clone()), - symbol: TypedFunctionSymbol::Here(funct), - }, - )); + TypedFunctionSymbol::Here(funct), + ) + .into(), + ); } Err(e) => { errors.extend(e.into_iter().map(|inner| inner.in_file(module_id))); @@ -740,10 +748,10 @@ impl<'ast, T: Field> Checker<'ast, T> { let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id); let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into()); - symbols.push(TypedSymbolDeclaration::Constant(TypedConstantSymbolDeclaration { - id: id.clone(), - symbol: TypedConstantSymbol::There(imported_id) - })); + symbols.push(TypedConstantSymbolDeclaration::new( + id.clone(), + TypedConstantSymbol::There(imported_id) + ).into()); self.insert_into_scope(Variable::with_id_and_type(CoreIdentifier::Constant(CanonicalConstantIdentifier::new( declaration.id, module_id.into(), @@ -787,11 +795,11 @@ impl<'ast, T: Field> Checker<'ast, T> { self.functions.insert(local_key.clone()); symbols.push( - TypedSymbolDeclaration::Function(TypedFunctionSymbolDeclaration { - key: local_key, - symbol: TypedFunctionSymbol::There(candidate, + TypedFunctionSymbolDeclaration::new( + local_key, + TypedFunctionSymbol::There(candidate, ), - }) + ).into() ); } } @@ -823,16 +831,17 @@ impl<'ast, T: Field> Checker<'ast, T> { DeclarationFunctionKey::with_location(module_id.to_path_buf(), declaration.id) .signature(funct.typed_signature()), ); - symbols.push(TypedSymbolDeclaration::Function( - TypedFunctionSymbolDeclaration { - key: DeclarationFunctionKey::with_location( + symbols.push( + TypedFunctionSymbolDeclaration::new( + DeclarationFunctionKey::with_location( module_id.to_path_buf(), declaration.id, ) .signature(funct.typed_signature()), - symbol: TypedFunctionSymbol::Flat(funct), - }, - )); + TypedFunctionSymbol::Flat(funct), + ) + .into(), + ); } _ => unreachable!(), }; @@ -896,9 +905,8 @@ impl<'ast, T: Field> Checker<'ast, T> { fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module - .symbols - .iter() - .filter(|s| matches!(s, TypedSymbolDeclaration::Function(d) if d.key.id == "main")) + .functions_iter() + .filter(|d| d.key.id == "main") .count() { 1 => Ok(()), diff --git a/zokrates_core/src/static_analysis/reducer/constants_reader.rs b/zokrates_core/src/static_analysis/reducer/constants_reader.rs index 9f39140a1..8fb1d4a1e 100644 --- a/zokrates_core/src/static_analysis/reducer/constants_reader.rs +++ b/zokrates_core/src/static_analysis/reducer/constants_reader.rs @@ -1,8 +1,11 @@ +// given a (partial) map of values for program constants, replace where applicable constants by their value + use crate::static_analysis::reducer::ConstantDefinitions; use crate::typed_absy::{ folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Identifier, StructExpression, - StructExpressionInner, StructType, UBitwidth, UExpression, UExpressionInner, + StructExpressionInner, StructType, TypedProgram, TypedSymbolDeclaration, UBitwidth, + UExpression, UExpressionInner, }; use zokrates_field::Field; @@ -12,10 +15,21 @@ pub struct ConstantsReader<'a, 'ast, T> { constants: &'a ConstantDefinitions<'ast, T>, } -impl<'a, 'ast, T> ConstantsReader<'a, 'ast, T> { +impl<'a, 'ast, T: Field> ConstantsReader<'a, 'ast, T> { pub fn with_constants(constants: &'a ConstantDefinitions<'ast, T>) -> Self { Self { constants } } + + pub fn read_into_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + self.fold_program(p) + } + + pub fn read_into_symbol_declaration( + &mut self, + d: TypedSymbolDeclaration<'ast, T>, + ) -> TypedSymbolDeclaration<'ast, T> { + self.fold_symbol_declaration(d) + } } impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { @@ -37,7 +51,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { }), } } - e => crate::typed_absy::folder::fold_field_expression(self, e), + e => fold_field_expression(self, e), } } @@ -59,7 +73,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { }), } } - e => crate::typed_absy::folder::fold_boolean_expression(self, e), + e => fold_boolean_expression(self, e), } } @@ -83,7 +97,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { }), } } - e => crate::typed_absy::folder::fold_uint_expression_inner(self, ty, e), + e => fold_uint_expression_inner(self, ty, e), } } @@ -106,7 +120,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { }), } } - e => crate::typed_absy::folder::fold_array_expression_inner(self, ty, e), + e => fold_array_expression_inner(self, ty, e), } } @@ -129,7 +143,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { }), } } - e => crate::typed_absy::folder::fold_struct_expression_inner(self, ty, e), + e => fold_struct_expression_inner(self, ty, e), } } @@ -149,7 +163,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { None => DeclarationConstant::Constant(c), } } - c => crate::typed_absy::folder::fold_declaration_constant(self, c), + c => fold_declaration_constant(self, c), } } } diff --git a/zokrates_core/src/static_analysis/reducer/constants_writer.rs b/zokrates_core/src/static_analysis/reducer/constants_writer.rs new file mode 100644 index 000000000..dfec62236 --- /dev/null +++ b/zokrates_core/src/static_analysis/reducer/constants_writer.rs @@ -0,0 +1,163 @@ +// A folder to inline all constant definitions down to a single litteral and register them in the state for later use. + +use crate::static_analysis::reducer::{ + constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error, +}; +use crate::typed_absy::{ + result_folder::*, types::ConcreteGenericsAssignment, OwnedTypedModuleId, TypedConstant, + TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, + TypedSymbolDeclaration, UExpression, +}; +use std::collections::{BTreeMap, HashSet}; +use zokrates_field::Field; + +pub struct ConstantsWriter<'ast, T> { + treated: HashSet, + constants: ConstantDefinitions<'ast, T>, + location: OwnedTypedModuleId, + program: TypedProgram<'ast, T>, +} + +impl<'ast, T: Field> ConstantsWriter<'ast, T> { + pub fn with_program(program: TypedProgram<'ast, T>) -> Self { + ConstantsWriter { + constants: ConstantDefinitions::default(), + location: program.main.clone(), + treated: HashSet::default(), + program, + } + } + + fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { + let prev = self.location.clone(); + self.location = location; + self.treated.insert(self.location.clone()); + prev + } + + fn treated(&self, id: &TypedModuleId) -> bool { + self.treated.contains(id) + } + + fn update_program(&mut self) { + let mut p = TypedProgram { + main: "".into(), + modules: BTreeMap::default(), + }; + std::mem::swap(&mut self.program, &mut p); + self.program = ConstantsReader::with_constants(&self.constants).read_into_program(p); + } + + fn update_symbol_declaration( + &self, + d: TypedSymbolDeclaration<'ast, T>, + ) -> TypedSymbolDeclaration<'ast, T> { + ConstantsReader::with_constants(&self.constants).read_into_symbol_declaration(d) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { + type Error = Error; + + fn fold_module_id( + &mut self, + id: OwnedTypedModuleId, + ) -> Result { + // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet + if !self.treated(&id) { + let current_m_id = self.change_location(id.clone()); + // I did not find a way to achieve this without cloning the module. Assuming we do not clone: + // to fold the module, we need to consume it, so it gets removed from the modules + // but to inline the calls while folding the module, all modules must be present + // therefore we clone... + // this does not lead to a module being folded more than once, as the first time + // we change location to this module, it's added to the `treated` set + let m = self.program.modules.get(&id).cloned().unwrap(); + let m = self.fold_module(m)?; + self.program.modules.insert(id.clone(), m); + self.change_location(current_m_id); + } + Ok(id) + } + + fn fold_symbol_declaration( + &mut self, + s: TypedSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + // before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module. + let s = self.update_symbol_declaration(s); + + let s = fold_symbol_declaration(self, s)?; + + // after we treat the symbol, propagate again, as treating this symbol may have triggered checking another module, resolving new constants which this symbol may be using. + Ok(self.update_symbol_declaration(s)) + } + + fn fold_constant_symbol_declaration( + &mut self, + d: TypedConstantSymbolDeclaration<'ast, T>, + ) -> Result, Self::Error> { + let id = self.fold_canonical_constant_identifier(d.id)?; + + match d.symbol { + TypedConstantSymbol::Here(c) => { + let c = self.fold_constant(c)?; + + use crate::typed_absy::{DeclarationSignature, TypedFunction, TypedStatement}; + + // wrap this expression in a function + let wrapper = TypedFunction { + arguments: vec![], + statements: vec![TypedStatement::Return(vec![c.expression])], + signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), + }; + + let mut inlined_wrapper = reduce_function( + wrapper, + ConcreteGenericsAssignment::default(), + &self.program, + )?; + + if let TypedStatement::Return(mut expressions) = + inlined_wrapper.statements.pop().unwrap() + { + assert_eq!(expressions.len(), 1); + let constant_expression = expressions.pop().unwrap(); + + use crate::typed_absy::Constant; + if !constant_expression.is_constant() { + return Err(Error::ConstantReduction(id.id.to_string(), id.module)); + }; + + use crate::typed_absy::Typed; + if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>( + c.ty.clone(), + ) + .unwrap() + == constant_expression.get_type() + { + // add to the constant map + self.constants + .insert(id.clone(), constant_expression.clone()); + + // after we reduced a constant, propagate it through the whole program + self.update_program(); + + Ok(TypedConstantSymbolDeclaration { + id, + symbol: TypedConstantSymbol::Here(TypedConstant { + expression: constant_expression, + ty: c.ty, + }), + }) + } else { + Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, c.ty))) + } + } else { + Err(Error::ConstantReduction(id.id.to_string(), id.module)) + } + } + _ => unreachable!("all constants should be local"), + } + } +} diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index 3d6967438..401593c35 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -12,6 +12,7 @@ // - inline function calls. This includes applying shallow-ssa on the target function mod constants_reader; +mod constants_writer; mod inline; mod shallow_ssa; @@ -21,20 +22,19 @@ use crate::typed_absy::types::ConcreteGenericsAssignment; use crate::typed_absy::types::GGenericsAssignment; use crate::typed_absy::CanonicalConstantIdentifier; use crate::typed_absy::Folder; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::HashMap; use crate::typed_absy::{ - ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, DeclarationSignature, Expr, - FunctionCall, FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, - OwnedTypedModuleId, TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, + ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, + FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, TypedExpression, TypedExpressionList, TypedExpressionListInner, TypedFunction, - TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedModuleId, TypedProgram, - TypedStatement, TypedSymbolDeclaration, UExpression, UExpressionInner, Variable, + TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, TypedStatement, + UExpression, UExpressionInner, Variable, }; use zokrates_field::Field; -use self::constants_reader::ConstantsReader; +use self::constants_writer::ConstantsWriter; use self::shallow_ssa::ShallowTransformer; use crate::static_analysis::propagation::{Constants, Propagator}; @@ -44,161 +44,9 @@ use std::fmt; const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20); // A map to register the canonical value of all constants. The values must be literals. -type ConstantDefinitions<'ast, T> = +pub type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; -// A folder to inline all constant definitions down to a single litteral. Also register them in the state for later use. -struct ConstantsBuilder<'ast, T> { - treated: HashSet, - constants: ConstantDefinitions<'ast, T>, - location: OwnedTypedModuleId, - program: TypedProgram<'ast, T>, -} - -impl<'ast, T: Field> ConstantsBuilder<'ast, T> { - fn with_program(program: TypedProgram<'ast, T>) -> Self { - ConstantsBuilder { - constants: ConstantDefinitions::default(), - location: program.main.clone(), - treated: HashSet::default(), - program, - } - } - - fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId { - let prev = self.location.clone(); - self.location = location; - self.treated.insert(self.location.clone()); - prev - } - - fn treated(&self, id: &TypedModuleId) -> bool { - self.treated.contains(id) - } - - fn update_program(&mut self) { - let mut p = TypedProgram { - main: "".into(), - modules: BTreeMap::default(), - }; - std::mem::swap(&mut self.program, &mut p); - let mut reader = ConstantsReader::with_constants(&self.constants); - self.program = reader.fold_program(p); - } - - fn update_symbol_declaration( - &self, - s: TypedSymbolDeclaration<'ast, T>, - ) -> TypedSymbolDeclaration<'ast, T> { - let mut reader = ConstantsReader::with_constants(&self.constants); - reader.fold_symbol_declaration(s) - } -} - -impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsBuilder<'ast, T> { - type Error = Error; - - fn fold_module_id( - &mut self, - id: OwnedTypedModuleId, - ) -> Result { - // anytime we encounter a module id, visit the corresponding module if it hasn't been done yet - if !self.treated(&id) { - let current_m_id = self.change_location(id.clone()); - // I did not find a way to achieve this without cloning the module. Assuming we do not clone: - // to fold the module, we need to consume it, so it gets removed from the modules - // but to inline the calls while folding the module, all modules must be present - // therefore we clone... - // this does not lead to a module being folded more than once, as the first time - // we change location to this module, it's added to the `treated` set - let m = self.program.modules.get(&id).cloned().unwrap(); - let m = self.fold_module(m)?; - self.program.modules.insert(id.clone(), m); - self.change_location(current_m_id); - } - Ok(id) - } - - fn fold_symbol_declaration( - &mut self, - s: TypedSymbolDeclaration<'ast, T>, - ) -> Result, Self::Error> { - // before we treat the symbol, propagate the constants into it, as may be using constants defined earlier in this module. - let s = self.update_symbol_declaration(s); - - let s = fold_symbol_declaration(self, s)?; - - // after we treat the symbol, propagate again, as treating this symbol may have triggered checking another module, resolving new constants which this symbol may be using. - Ok(self.update_symbol_declaration(s)) - } - - fn fold_constant_symbol_declaration( - &mut self, - d: TypedConstantSymbolDeclaration<'ast, T>, - ) -> Result, Self::Error> { - let id = self.fold_canonical_constant_identifier(d.id)?; - - match d.symbol { - TypedConstantSymbol::Here(c) => { - let c = self.fold_constant(c)?; - - // wrap this expression in a function - let wrapper = TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![c.expression])], - signature: DeclarationSignature::new().outputs(vec![c.ty.clone()]), - }; - - let mut inlined_wrapper = reduce_function( - wrapper, - ConcreteGenericsAssignment::default(), - &self.program, - )?; - - if let TypedStatement::Return(mut expressions) = - inlined_wrapper.statements.pop().unwrap() - { - assert_eq!(expressions.len(), 1); - let constant_expression = expressions.pop().unwrap(); - - use crate::typed_absy::Constant; - if !constant_expression.is_constant() { - return Err(Error::ConstantReduction(id.id.to_string(), id.module)); - }; - - use crate::typed_absy::Typed; - if crate::typed_absy::types::try_from_g_type::<_, UExpression<'ast, T>>( - c.ty.clone(), - ) - .unwrap() - == constant_expression.get_type() - { - // add to the constant map - self.constants - .insert(id.clone(), constant_expression.clone()); - - // after we reduced a constant, propagate it through the whole program - self.update_program(); - - Ok(TypedConstantSymbolDeclaration { - id, - symbol: TypedConstantSymbol::Here(TypedConstant { - expression: constant_expression, - ty: c.ty, - }), - }) - } else { - Err(Error::Type(format!("Expression of type `{}` cannot be assigned to constant `{}` of type `{}`", constant_expression.get_type(), id, c.ty))) - } - } else { - Err(Error::ConstantReduction(id.id.to_string(), id.module)) - } - } - _ => unreachable!("all constants should be local"), - } - } -} - // An SSA version map, giving access to the latest version number for each identifier pub type Versions<'ast> = HashMap, usize>; @@ -662,9 +510,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { pub fn reduce_program(p: TypedProgram) -> Result, Error> { // inline all constants and replace them in the program - let mut constant_calls_inliner = ConstantsBuilder::with_program(p.clone()); + let mut constants_writer = ConstantsWriter::with_program(p.clone()); - let p = constant_calls_inliner.fold_program(p)?; + let p = constants_writer.fold_program(p)?; // inline starting from main let main_module = p.modules.get(&p.main).unwrap().clone(); @@ -688,12 +536,11 @@ pub fn reduce_program(p: TypedProgram) -> Result, E modules: vec![( p.main.clone(), TypedModule { - symbols: vec![TypedSymbolDeclaration::Function( - TypedFunctionSymbolDeclaration { - key: decl.key.clone(), - symbol: TypedFunctionSymbol::Here(main_function), - }, - )], + symbols: vec![TypedFunctionSymbolDeclaration::new( + decl.key.clone(), + TypedFunctionSymbol::Here(main_function), + ) + .into()], }, )] .into_iter() From 43f4934586b0cef80f56f8ac6631c32953c0cf09 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 13:00:36 +0300 Subject: [PATCH 50/78] detect out of bounds reads and writes --- .../compile_errors/out_of_bounds_read.zok | 7 ++ .../compile_errors/out_of_bounds_write.zok | 7 ++ .../constant_argument_checker.rs | 23 ++++-- zokrates_core/src/static_analysis/mod.rs | 15 ++++ .../src/static_analysis/out_of_bounds.rs | 78 +++++++++++++++++++ zokrates_core/src/typed_absy/result_folder.rs | 25 +++--- 6 files changed, 138 insertions(+), 17 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/out_of_bounds_read.zok create mode 100644 zokrates_cli/examples/compile_errors/out_of_bounds_write.zok create mode 100644 zokrates_core/src/static_analysis/out_of_bounds.rs diff --git a/zokrates_cli/examples/compile_errors/out_of_bounds_read.zok b/zokrates_cli/examples/compile_errors/out_of_bounds_read.zok new file mode 100644 index 000000000..7401e1d5d --- /dev/null +++ b/zokrates_cli/examples/compile_errors/out_of_bounds_read.zok @@ -0,0 +1,7 @@ +def foo(field[1] a) -> field[1]: + return a + +def main(field a): + field[1] h = foo([a]) + field f = h[1] + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok b/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok new file mode 100644 index 000000000..7401e1d5d --- /dev/null +++ b/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok @@ -0,0 +1,7 @@ +def foo(field[1] a) -> field[1]: + return a + +def main(field a): + field[1] h = foo([a]) + field f = h[1] + return \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_argument_checker.rs b/zokrates_core/src/static_analysis/constant_argument_checker.rs index 91ec166ef..42bb3b668 100644 --- a/zokrates_core/src/static_analysis/constant_argument_checker.rs +++ b/zokrates_core/src/static_analysis/constant_argument_checker.rs @@ -5,7 +5,9 @@ use crate::typed_absy::{ result_folder::{fold_expression_list_inner, fold_uint_expression_inner}, Constant, TypedExpressionListInner, Types, UBitwidth, UExpressionInner, }; +use std::fmt; use zokrates_field::Field; + pub struct ConstantArgumentChecker; impl ConstantArgumentChecker { @@ -14,7 +16,14 @@ impl ConstantArgumentChecker { } } -pub type Error = String; +#[derive(Debug)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { type Error = Error; @@ -31,11 +40,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { match by.as_inner() { UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), - by => Err(format!( + by => Err(Error(format!( "Cannot shift by a variable value, found `{} << {}`", e, by.clone().annotate(UBitwidth::B32) - )), + ))), } } UExpressionInner::RightShift(box e, box by) => { @@ -44,11 +53,11 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { match by.as_inner() { UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), - by => Err(format!( + by => Err(Error(format!( "Cannot shift by a variable value, found `{} >> {}`", e, by.clone().annotate(UBitwidth::B32) - )), + ))), } } e => fold_uint_expression_inner(self, bitwidth, e), @@ -74,10 +83,10 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker { arguments, )) } else { - Err(format!( + Err(Error(format!( "Cannot compare to a variable value, found `{}`", arguments[1] - )) + ))) } } l => fold_expression_list_inner(self, tys, l), diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 7078f02ba..4c3ce3e07 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -9,6 +9,7 @@ mod constant_argument_checker; mod constant_inliner; mod flat_propagation; mod flatten_complex_types; +mod out_of_bounds; mod propagation; mod reducer; mod uint_optimizer; @@ -19,6 +20,7 @@ mod zir_propagation; use self::branch_isolator::Isolator; use self::constant_argument_checker::ConstantArgumentChecker; use self::flatten_complex_types::Flattener; +use self::out_of_bounds::OutOfBoundsChecker; use self::propagation::Propagator; use self::reducer::reduce_program; use self::uint_optimizer::UintOptimizer; @@ -48,6 +50,7 @@ pub enum Error { NonConstantArgument(self::constant_argument_checker::Error), ConstantInliner(self::constant_inliner::Error), UnconstrainedVariable(self::unconstrained_vars::Error), + OutOfBounds(self::out_of_bounds::Error), } impl From for Error { @@ -74,6 +77,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: out_of_bounds::Error) -> Self { + Error::OutOfBounds(e) + } +} + impl From for Error { fn from(e: constant_argument_checker::Error) -> Self { Error::NonConstantArgument(e) @@ -95,6 +104,7 @@ impl fmt::Display for Error { Error::NonConstantArgument(e) => write!(f, "{}", e), Error::ConstantInliner(e) => write!(f, "{}", e), Error::UnconstrainedVariable(e) => write!(f, "{}", e), + Error::OutOfBounds(e) => write!(f, "{}", e), } } } @@ -141,6 +151,11 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = ConstantArgumentChecker::check(r).map_err(Error::from)?; log::trace!("\n{}", r); + // detect out of bounds reads and writes + log::debug!("Static analyser: Detect out of bound accesses"); + let r = OutOfBoundsChecker::check(r).map_err(Error::from)?; + log::trace!("\n{}", r); + // convert to zir, removing complex types log::debug!("Static analyser: Convert to zir"); let zir = Flattener::flatten(r); diff --git a/zokrates_core/src/static_analysis/out_of_bounds.rs b/zokrates_core/src/static_analysis/out_of_bounds.rs new file mode 100644 index 000000000..679ca152d --- /dev/null +++ b/zokrates_core/src/static_analysis/out_of_bounds.rs @@ -0,0 +1,78 @@ +use crate::typed_absy::{ + result_folder::*, Expr, SelectExpression, SelectOrExpression, Type, TypedAssignee, + TypedProgram, UExpressionInner, +}; +use std::fmt; +use zokrates_field::Field; + +#[derive(Default)] +pub struct OutOfBoundsChecker; + +#[derive(Debug)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} +impl OutOfBoundsChecker { + pub fn check(p: TypedProgram) -> Result, Error> { + Self::default().fold_program(p) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for OutOfBoundsChecker { + type Error = Error; + + fn fold_select_expression>( + &mut self, + _: &E::Ty, + s: SelectExpression<'ast, T, E>, + ) -> Result, Self::Error> { + match (s.index.as_inner(), s.array.size().as_inner()) { + (UExpressionInner::Value(index), UExpressionInner::Value(size)) if index >= size => { + Err(Error(format!( + "Out of bounds access `{}` because `{}` has size {}", + s, s.array, size + ))) + } + _ => Ok(SelectOrExpression::Select(s)), + } + } + + fn fold_assignee( + &mut self, + a: TypedAssignee<'ast, T>, + ) -> Result, Error> { + match a { + TypedAssignee::Select(box array, box index) => { + use crate::typed_absy::Typed; + + let array = self.fold_assignee(array)?; + + let size = match array.get_type() { + Type::Array(array_ty) => match array_ty.size.as_inner() { + UExpressionInner::Value(size) => *size, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + match index.as_inner() { + UExpressionInner::Value(i) if i >= &size => Err(Error(format!( + "Out of bounds write to `{}` because `{}` has size {}", + TypedAssignee::Select(box array.clone(), box index), + array, + size + ))), + _ => Ok(TypedAssignee::Select( + box self.fold_assignee(array)?, + box self.fold_uint_expression(index)?, + )), + } + } + a => fold_assignee(self, a), + } + } +} diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index c85e97a7d..468474ba4 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -294,16 +294,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { &mut self, a: TypedAssignee<'ast, T>, ) -> Result, Self::Error> { - match a { - TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(self.fold_variable(v)?)), - TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( - box self.fold_assignee(a)?, - box self.fold_uint_expression(index)?, - )), - TypedAssignee::Member(box s, m) => { - Ok(TypedAssignee::Member(box self.fold_assignee(s)?, m)) - } - } + fold_assignee(self, a) } fn fold_statement( @@ -518,6 +509,20 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( Ok(e) } +pub fn fold_assignee<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + a: TypedAssignee<'ast, T>, +) -> Result, F::Error> { + match a { + TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(f.fold_variable(v)?)), + TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( + box f.fold_assignee(a)?, + box f.fold_uint_expression(index)?, + )), + TypedAssignee::Member(box s, m) => Ok(TypedAssignee::Member(box f.fold_assignee(s)?, m)), + } +} + pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, ty: &StructType<'ast, T>, From 1d1dca3b65d365526a6b709722c01c0897314663 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 13:05:12 +0300 Subject: [PATCH 51/78] changelog --- changelogs/unreleased/1013-schaeff | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1013-schaeff diff --git a/changelogs/unreleased/1013-schaeff b/changelogs/unreleased/1013-schaeff new file mode 100644 index 000000000..b47a580e2 --- /dev/null +++ b/changelogs/unreleased/1013-schaeff @@ -0,0 +1 @@ +Handle out of bound accesses gracefully \ No newline at end of file From 77a3888600066aa19d47d4d95e9fb53dee4d2976 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 21 Sep 2021 13:07:26 +0300 Subject: [PATCH 52/78] fix example --- zokrates_cli/examples/compile_errors/out_of_bounds_write.zok | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok b/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok index 7401e1d5d..93904120f 100644 --- a/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok +++ b/zokrates_cli/examples/compile_errors/out_of_bounds_write.zok @@ -3,5 +3,5 @@ def foo(field[1] a) -> field[1]: def main(field a): field[1] h = foo([a]) - field f = h[1] + h[1] = 1 return \ No newline at end of file From f41fddb7ac7a223d3f73fc77c0a92965ef1d5686 Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Tue, 21 Sep 2021 14:16:37 +0200 Subject: [PATCH 53/78] Eliminate redundant code. --- zokrates_core/src/proof_system/scheme/gm17.rs | 4 +- .../src/proof_system/scheme/groth16.rs | 4 +- .../src/proof_system/scheme/pghr13.rs | 4 +- zokrates_core/src/proof_system/solidity.rs | 157 ++---------------- 4 files changed, 20 insertions(+), 149 deletions(-) diff --git a/zokrates_core/src/proof_system/scheme/gm17.rs b/zokrates_core/src/proof_system/scheme/gm17.rs index a50cd8833..bee43d26a 100644 --- a/zokrates_core/src/proof_system/scheme/gm17.rs +++ b/zokrates_core/src/proof_system/scheme/gm17.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::{SOLIDITY_G2_ADDITION_LIB, SOLIDITY_PAIRING_LIB}; +use crate::proof_system::solidity::{solidity_pairing_lib, SOLIDITY_G2_ADDITION_LIB}; use crate::proof_system::{ G1Affine, G2Affine, G2AffineFq, SolidityCompatibleField, SolidityCompatibleScheme, }; @@ -50,7 +50,7 @@ impl SolidityCompatibleScheme f fn export_solidity_verifier(vk: >::VerificationKey) -> String { let (mut template_text, solidity_pairing_lib) = ( String::from(CONTRACT_TEMPLATE), - String::from(SOLIDITY_PAIRING_LIB), + String::from(solidity_pairing_lib(false)), ); // replace things in template diff --git a/zokrates_core/src/proof_system/scheme/groth16.rs b/zokrates_core/src/proof_system/scheme/groth16.rs index a998c953c..077f6f2e8 100644 --- a/zokrates_core/src/proof_system/scheme/groth16.rs +++ b/zokrates_core/src/proof_system/scheme/groth16.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::SOLIDITY_PAIRING_LIB_SANS_BN256G2; +use crate::proof_system::solidity::solidity_pairing_lib; use crate::proof_system::{G1Affine, G2Affine, SolidityCompatibleField, SolidityCompatibleScheme}; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -34,7 +34,7 @@ impl SolidityCompatibleScheme for G16 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { let (mut template_text, solidity_pairing_lib_sans_bn256g2) = ( String::from(CONTRACT_TEMPLATE), - String::from(SOLIDITY_PAIRING_LIB_SANS_BN256G2), + String::from(solidity_pairing_lib(false)), ); let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); diff --git a/zokrates_core/src/proof_system/scheme/pghr13.rs b/zokrates_core/src/proof_system/scheme/pghr13.rs index 5362f245f..bc7e85d64 100644 --- a/zokrates_core/src/proof_system/scheme/pghr13.rs +++ b/zokrates_core/src/proof_system/scheme/pghr13.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::{SOLIDITY_G2_ADDITION_LIB, SOLIDITY_PAIRING_LIB}; +use crate::proof_system::solidity::{solidity_pairing_lib, SOLIDITY_G2_ADDITION_LIB}; use crate::proof_system::{G1Affine, G2Affine, SolidityCompatibleField, SolidityCompatibleScheme}; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -43,7 +43,7 @@ impl SolidityCompatibleScheme for PGHR13 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { let (mut template_text, solidity_pairing_lib) = ( String::from(CONTRACT_TEMPLATE), - String::from(SOLIDITY_PAIRING_LIB), + String::from(solidity_pairing_lib(false)), ); // replace things in template diff --git a/zokrates_core/src/proof_system/solidity.rs b/zokrates_core/src/proof_system/solidity.rs index d9380235c..92e3a6409 100644 --- a/zokrates_core/src/proof_system/solidity.rs +++ b/zokrates_core/src/proof_system/solidity.rs @@ -406,7 +406,8 @@ library BN256G2 { } "#; -pub const SOLIDITY_PAIRING_LIB: &str = r#"// This file is MIT Licensed. +pub fn solidity_pairing_lib(with_g2_addition: bool) -> String { + let pairingLibBeginning = r#"// This file is MIT Licensed. // // Copyright 2017 Christian Reitwiessner // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: @@ -459,154 +460,16 @@ library Pairing { } require(success); } +"#; + + let pairingLibG2Addition = r#" /// @return r the sum of two points of G2 function addition(G2Point memory p1, G2Point memory p2) internal view returns (G2Point memory r) { (r.X[0], r.X[1], r.Y[0], r.Y[1]) = BN256G2.ECTwistAdd(p1.X[0],p1.X[1],p1.Y[0],p1.Y[1],p2.X[0],p2.X[1],p2.Y[0],p2.Y[1]); } - /// @return r the product of a point on G1 and a scalar, i.e. - /// p == p.scalar_mul(1) and p.addition(p) == p.scalar_mul(2) for all points p. - function scalar_mul(G1Point memory p, uint s) internal view returns (G1Point memory r) { - uint[3] memory input; - input[0] = p.X; - input[1] = p.Y; - input[2] = s; - bool success; - assembly { - success := staticcall(sub(gas(), 2000), 7, input, 0x80, r, 0x60) - // Use "invalid" to make gas estimation work - switch success case 0 { invalid() } - } - require (success); - } - /// @return the result of computing the pairing check - /// e(p1[0], p2[0]) * .... * e(p1[n], p2[n]) == 1 - /// For example pairing([P1(), P1().negate()], [P2(), P2()]) should - /// return true. - function pairing(G1Point[] memory p1, G2Point[] memory p2) internal view returns (bool) { - require(p1.length == p2.length); - uint elements = p1.length; - uint inputSize = elements * 6; - uint[] memory input = new uint[](inputSize); - for (uint i = 0; i < elements; i++) - { - input[i * 6 + 0] = p1[i].X; - input[i * 6 + 1] = p1[i].Y; - input[i * 6 + 2] = p2[i].X[1]; - input[i * 6 + 3] = p2[i].X[0]; - input[i * 6 + 4] = p2[i].Y[1]; - input[i * 6 + 5] = p2[i].Y[0]; - } - uint[1] memory out; - bool success; - assembly { - success := staticcall(sub(gas(), 2000), 8, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) - // Use "invalid" to make gas estimation work - switch success case 0 { invalid() } - } - require(success); - return out[0] != 0; - } - /// Convenience method for a pairing check for two pairs. - function pairingProd2(G1Point memory a1, G2Point memory a2, G1Point memory b1, G2Point memory b2) internal view returns (bool) { - G1Point[] memory p1 = new G1Point[](2); - G2Point[] memory p2 = new G2Point[](2); - p1[0] = a1; - p1[1] = b1; - p2[0] = a2; - p2[1] = b2; - return pairing(p1, p2); - } - /// Convenience method for a pairing check for three pairs. - function pairingProd3( - G1Point memory a1, G2Point memory a2, - G1Point memory b1, G2Point memory b2, - G1Point memory c1, G2Point memory c2 - ) internal view returns (bool) { - G1Point[] memory p1 = new G1Point[](3); - G2Point[] memory p2 = new G2Point[](3); - p1[0] = a1; - p1[1] = b1; - p1[2] = c1; - p2[0] = a2; - p2[1] = b2; - p2[2] = c2; - return pairing(p1, p2); - } - /// Convenience method for a pairing check for four pairs. - function pairingProd4( - G1Point memory a1, G2Point memory a2, - G1Point memory b1, G2Point memory b2, - G1Point memory c1, G2Point memory c2, - G1Point memory d1, G2Point memory d2 - ) internal view returns (bool) { - G1Point[] memory p1 = new G1Point[](4); - G2Point[] memory p2 = new G2Point[](4); - p1[0] = a1; - p1[1] = b1; - p1[2] = c1; - p1[3] = d1; - p2[0] = a2; - p2[1] = b2; - p2[2] = c2; - p2[3] = d2; - return pairing(p1, p2); - } -} "#; -pub const SOLIDITY_PAIRING_LIB_SANS_BN256G2: &str = r#"// This file is MIT Licensed. -// -// Copyright 2017 Christian Reitwiessner -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -pragma solidity ^0.8.0; -library Pairing { - struct G1Point { - uint X; - uint Y; - } - // Encoding of field elements is: X[0] * z + X[1] - struct G2Point { - uint[2] X; - uint[2] Y; - } - /// @return the generator of G1 - function P1() pure internal returns (G1Point memory) { - return G1Point(1, 2); - } - /// @return the generator of G2 - function P2() pure internal returns (G2Point memory) { - return G2Point( - [10857046999023057135944570762232829481370756359578518086990519993285655852781, - 11559732032986387107991004021392285783925812861821192530917403151452391805634], - [8495653923123431417604973247489272438418190587263600148770280649306958101930, - 4082367875863433681332203403145435568316851327593401208105741076214120093531] - ); - } - /// @return the negation of p, i.e. p.addition(p.negate()) should be zero. - function negate(G1Point memory p) pure internal returns (G1Point memory) { - // The prime q in the base field F_q for G1 - uint q = 21888242871839275222246405745257275088696311157297823662689037894645226208583; - if (p.X == 0 && p.Y == 0) - return G1Point(0, 0); - return G1Point(p.X, q - (p.Y % q)); - } - /// @return r the sum of two points of G1 - function addition(G1Point memory p1, G1Point memory p2) internal view returns (G1Point memory r) { - uint[4] memory input; - input[0] = p1.X; - input[1] = p1.Y; - input[2] = p2.X; - input[3] = p2.Y; - bool success; - assembly { - success := staticcall(sub(gas(), 2000), 6, input, 0xc0, r, 0x60) - // Use "invalid" to make gas estimation work - switch success case 0 { invalid() } - } - require(success); - } + let pairingLibEnding = r#" /// @return r the product of a point on G1 and a scalar, i.e. /// p == p.scalar_mul(1) and p.addition(p) == p.scalar_mul(2) for all points p. function scalar_mul(G1Point memory p, uint s) internal view returns (G1Point memory r) { @@ -697,3 +560,11 @@ library Pairing { } } "#; + + let pairingLib = if (!with_g2_addition) { + [pairingLibBeginning, pairingLibEnding].join("\n") + } else { + [pairingLibBeginning, pairingLibG2Addition, pairingLibEnding].join("\n") + }; + return pairingLib; +} From d21a032dca111c0f80c4854d29f8d59ddd2ef9f9 Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Tue, 21 Sep 2021 14:24:51 +0200 Subject: [PATCH 54/78] Fix formatting violations that are not fixed by cargo fmt. --- zokrates_core/src/proof_system/solidity.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/zokrates_core/src/proof_system/solidity.rs b/zokrates_core/src/proof_system/solidity.rs index 92e3a6409..975e1b5b1 100644 --- a/zokrates_core/src/proof_system/solidity.rs +++ b/zokrates_core/src/proof_system/solidity.rs @@ -407,7 +407,7 @@ library BN256G2 { "#; pub fn solidity_pairing_lib(with_g2_addition: bool) -> String { - let pairingLibBeginning = r#"// This file is MIT Licensed. + let pairing_lib_beginning = r#"// This file is MIT Licensed. // // Copyright 2017 Christian Reitwiessner // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: @@ -462,14 +462,14 @@ library Pairing { } "#; - let pairingLibG2Addition = r#" + let pairing_lib_g2_addition = r#" /// @return r the sum of two points of G2 function addition(G2Point memory p1, G2Point memory p2) internal view returns (G2Point memory r) { (r.X[0], r.X[1], r.Y[0], r.Y[1]) = BN256G2.ECTwistAdd(p1.X[0],p1.X[1],p1.Y[0],p1.Y[1],p2.X[0],p2.X[1],p2.Y[0],p2.Y[1]); } "#; - let pairingLibEnding = r#" + let pairing_lib_ending = r#" /// @return r the product of a point on G1 and a scalar, i.e. /// p == p.scalar_mul(1) and p.addition(p) == p.scalar_mul(2) for all points p. function scalar_mul(G1Point memory p, uint s) internal view returns (G1Point memory r) { @@ -561,10 +561,10 @@ library Pairing { } "#; - let pairingLib = if (!with_g2_addition) { - [pairingLibBeginning, pairingLibEnding].join("\n") + let pairing_lib = if !with_g2_addition { + [pairing_lib_beginning, pairing_lib_ending].join("\n") } else { - [pairingLibBeginning, pairingLibG2Addition, pairingLibEnding].join("\n") + [pairing_lib_beginning, pairing_lib_g2_addition, pairing_lib_ending].join("\n") }; - return pairingLib; + return pairing_lib; } From d55ddc639d4d3ca11085f18542d8da671fe646ad Mon Sep 17 00:00:00 2001 From: Darko Macesic Date: Thu, 23 Sep 2021 15:33:09 +0200 Subject: [PATCH 55/78] Update unconstrained_vars.rs --- zokrates_core/src/static_analysis/unconstrained_vars.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index dcade6835..5e3c78c07 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -18,7 +18,7 @@ impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "Found unconstrained variables during IR analysis (found {} occurrence{})", + "Found unconstrained variables during IR analysis (found {} occurrence{}). If this is intentional, use the `--allow-unconstrained-variables` flag.", self.0, if self.0 == 1 { "" } else { "s" } ) From 4993b71455ac65c1148aa48574e86a044de96bdd Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Thu, 23 Sep 2021 21:53:23 +0200 Subject: [PATCH 56/78] Format code again. --- zokrates_core/src/proof_system/solidity.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/zokrates_core/src/proof_system/solidity.rs b/zokrates_core/src/proof_system/solidity.rs index 975e1b5b1..36f8933c0 100644 --- a/zokrates_core/src/proof_system/solidity.rs +++ b/zokrates_core/src/proof_system/solidity.rs @@ -564,7 +564,12 @@ library Pairing { let pairing_lib = if !with_g2_addition { [pairing_lib_beginning, pairing_lib_ending].join("\n") } else { - [pairing_lib_beginning, pairing_lib_g2_addition, pairing_lib_ending].join("\n") + [ + pairing_lib_beginning, + pairing_lib_g2_addition, + pairing_lib_ending, + ] + .join("\n") }; return pairing_lib; } From 58389bf8095961ec5cca9339bf58d63716554206 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Sep 2021 13:32:59 +0200 Subject: [PATCH 57/78] fix tests --- zokrates_core/src/static_analysis/unconstrained_vars.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/static_analysis/unconstrained_vars.rs b/zokrates_core/src/static_analysis/unconstrained_vars.rs index 5e3c78c07..52793cece 100644 --- a/zokrates_core/src/static_analysis/unconstrained_vars.rs +++ b/zokrates_core/src/static_analysis/unconstrained_vars.rs @@ -86,7 +86,7 @@ mod tests { let result = UnconstrainedVariableDetector::detect(&p); assert_eq!( result.expect_err("expected an error").to_string(), - "Found unconstrained variables during IR analysis (found 1 occurrence)" + "Found unconstrained variables during IR analysis (found 1 occurrence). If this is intentional, use the `--allow-unconstrained-variables` flag." ); } From 3c97ecfbcf9ea47eedb7ff13a61235bb55761b03 Mon Sep 17 00:00:00 2001 From: Christoph Michelbach Date: Fri, 24 Sep 2021 14:41:05 +0200 Subject: [PATCH 58/78] Apply cargo clippy rules. --- zokrates_core/src/proof_system/scheme/gm17.rs | 6 ++---- zokrates_core/src/proof_system/scheme/groth16.rs | 6 ++---- zokrates_core/src/proof_system/scheme/pghr13.rs | 6 ++---- zokrates_core/src/proof_system/solidity.rs | 5 ++--- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/zokrates_core/src/proof_system/scheme/gm17.rs b/zokrates_core/src/proof_system/scheme/gm17.rs index bee43d26a..4bc389f2b 100644 --- a/zokrates_core/src/proof_system/scheme/gm17.rs +++ b/zokrates_core/src/proof_system/scheme/gm17.rs @@ -48,10 +48,8 @@ impl Scheme for GM17 { impl SolidityCompatibleScheme for GM17 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { - let (mut template_text, solidity_pairing_lib) = ( - String::from(CONTRACT_TEMPLATE), - String::from(solidity_pairing_lib(false)), - ); + let (mut template_text, solidity_pairing_lib) = + (String::from(CONTRACT_TEMPLATE), solidity_pairing_lib(false)); // replace things in template let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); diff --git a/zokrates_core/src/proof_system/scheme/groth16.rs b/zokrates_core/src/proof_system/scheme/groth16.rs index 077f6f2e8..a3f3d1d03 100644 --- a/zokrates_core/src/proof_system/scheme/groth16.rs +++ b/zokrates_core/src/proof_system/scheme/groth16.rs @@ -32,10 +32,8 @@ impl NonUniversalScheme for G16 {} impl SolidityCompatibleScheme for G16 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { - let (mut template_text, solidity_pairing_lib_sans_bn256g2) = ( - String::from(CONTRACT_TEMPLATE), - String::from(solidity_pairing_lib(false)), - ); + let (mut template_text, solidity_pairing_lib_sans_bn256g2) = + (String::from(CONTRACT_TEMPLATE), solidity_pairing_lib(false)); let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); let vk_gamma_abc_len_regex = Regex::new(r#"(<%vk_gamma_abc_length%>)"#).unwrap(); diff --git a/zokrates_core/src/proof_system/scheme/pghr13.rs b/zokrates_core/src/proof_system/scheme/pghr13.rs index bc7e85d64..567a3441b 100644 --- a/zokrates_core/src/proof_system/scheme/pghr13.rs +++ b/zokrates_core/src/proof_system/scheme/pghr13.rs @@ -41,10 +41,8 @@ impl NonUniversalScheme for PGHR13 {} impl SolidityCompatibleScheme for PGHR13 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { - let (mut template_text, solidity_pairing_lib) = ( - String::from(CONTRACT_TEMPLATE), - String::from(solidity_pairing_lib(false)), - ); + let (mut template_text, solidity_pairing_lib) = + (String::from(CONTRACT_TEMPLATE), solidity_pairing_lib(false)); // replace things in template let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); diff --git a/zokrates_core/src/proof_system/solidity.rs b/zokrates_core/src/proof_system/solidity.rs index 36f8933c0..c5471962b 100644 --- a/zokrates_core/src/proof_system/solidity.rs +++ b/zokrates_core/src/proof_system/solidity.rs @@ -561,7 +561,7 @@ library Pairing { } "#; - let pairing_lib = if !with_g2_addition { + if !with_g2_addition { [pairing_lib_beginning, pairing_lib_ending].join("\n") } else { [ @@ -570,6 +570,5 @@ library Pairing { pairing_lib_ending, ] .join("\n") - }; - return pairing_lib; + } } From ca55c984e3a85db79b88e4f773427d32132c85d1 Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 24 Sep 2021 17:54:16 +0300 Subject: [PATCH 59/78] address review comments --- .../examples/array_generic_inference.zok | 5 ++-- zokrates_cli/examples/call_in_constant.zok | 2 -- .../compile_errors/ambiguous_generic_call.zok | 2 +- .../examples/struct_generic_inference.zok | 5 ++-- .../src/static_analysis/constant_resolver.rs | 26 +++++++++---------- .../reducer/constants_writer.rs | 2 +- 6 files changed, 19 insertions(+), 23 deletions(-) diff --git a/zokrates_cli/examples/array_generic_inference.zok b/zokrates_cli/examples/array_generic_inference.zok index c0cee5776..40da76a02 100644 --- a/zokrates_cli/examples/array_generic_inference.zok +++ b/zokrates_cli/examples/array_generic_inference.zok @@ -1,12 +1,11 @@ def myFct(u64[N] ignored) -> u64[N2]: assert(2*N == N2) - return [0; N2] - const u32 N = 3 + const u32 N2 = 2*N + def main(u64[N] arg) -> bool: u64[N2] someVariable = myFct(arg) - return true \ No newline at end of file diff --git a/zokrates_cli/examples/call_in_constant.zok b/zokrates_cli/examples/call_in_constant.zok index b334ed260..2d1312b50 100644 --- a/zokrates_cli/examples/call_in_constant.zok +++ b/zokrates_cli/examples/call_in_constant.zok @@ -1,5 +1,3 @@ -// calling a function inside a constant definition is not possible yet - def yes() -> bool: return true diff --git a/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok b/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok index 5224a4be0..5cb7685e5 100644 --- a/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok +++ b/zokrates_cli/examples/compile_errors/ambiguous_generic_call.zok @@ -1,6 +1,6 @@ // this should not compile, as A == B -const u32 A = 2 +const u32 A = 1 const u32 B = 1 def foo(field[A] a) -> bool: diff --git a/zokrates_cli/examples/struct_generic_inference.zok b/zokrates_cli/examples/struct_generic_inference.zok index 07dbb53bb..0e094a2ac 100644 --- a/zokrates_cli/examples/struct_generic_inference.zok +++ b/zokrates_cli/examples/struct_generic_inference.zok @@ -4,13 +4,12 @@ struct SomeStruct { def myFct(SomeStruct ignored) -> u32[N2]: assert(2*N == N2) - return [N3; N2] - const u32 N = 3 + const u32 N2 = 2*N + def main(SomeStruct arg) -> u32: u32[N2] someVariable = myFct::<_, _, 42>(arg) - return someVariable[0] \ No newline at end of file diff --git a/zokrates_core/src/static_analysis/constant_resolver.rs b/zokrates_core/src/static_analysis/constant_resolver.rs index b6c4f1368..9f9ed5d2e 100644 --- a/zokrates_core/src/static_analysis/constant_resolver.rs +++ b/zokrates_core/src/static_analysis/constant_resolver.rs @@ -11,19 +11,19 @@ use zokrates_field::Field; type ProgramConstants<'ast, T> = HashMap, TypedConstant<'ast, T>>>; -pub struct ConstantInliner<'ast, T> { +pub struct ConstantResolver<'ast, T> { modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, constants: ProgramConstants<'ast, T>, } -impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { +impl<'ast, 'a, T: Field> ConstantResolver<'ast, T> { pub fn new( modules: TypedModules<'ast, T>, location: OwnedTypedModuleId, constants: ProgramConstants<'ast, T>, ) -> Self { - ConstantInliner { + ConstantResolver { modules, location, constants, @@ -31,7 +31,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { let constants = ProgramConstants::new(); - let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants); + let mut inliner = ConstantResolver::new(p.modules.clone(), p.main.clone(), constants); inliner.fold_program(p) } @@ -57,7 +57,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> { } } -impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { +impl<'ast, T: Field> Folder<'ast, T> for ConstantResolver<'ast, T> { fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { self.fold_module_id(p.main.clone()); @@ -89,7 +89,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> { TypedConstantSymbol::There(imported_id) => { // visit the imported symbol. This triggers visiting the corresponding module if needed let imported_id = self.fold_canonical_constant_identifier(imported_id); - // after that, the constant must have been defined defined in the global map + // after that, the constant must have been defined in the global map self.get_constant(&imported_id).unwrap() } TypedConstantSymbol::Here(c) => fold_constant(self, c), @@ -171,7 +171,7 @@ mod tests { let expected_program = program.clone(); - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); assert_eq!(program, expected_program) } @@ -229,7 +229,7 @@ mod tests { let expected_program = program.clone(); - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); assert_eq!(program, expected_program) } @@ -290,7 +290,7 @@ mod tests { let expected_program = program.clone(); - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); assert_eq!(program, expected_program) } @@ -373,7 +373,7 @@ mod tests { let expected_program = program.clone(); - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); assert_eq!(program, expected_program) } @@ -446,7 +446,7 @@ mod tests { let expected_program = program.clone(); - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); assert_eq!(program, expected_program) } @@ -566,7 +566,7 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); let expected_main_module = TypedModule { symbols: vec![ TypedConstantSymbolDeclaration::new( @@ -764,7 +764,7 @@ mod tests { .collect(), }; - let program = ConstantInliner::inline(program); + let program = ConstantResolver::inline(program); let expected_main_module = TypedModule { symbols: vec![ TypedConstantSymbolDeclaration::new( diff --git a/zokrates_core/src/static_analysis/reducer/constants_writer.rs b/zokrates_core/src/static_analysis/reducer/constants_writer.rs index dfec62236..8b6f645a8 100644 --- a/zokrates_core/src/static_analysis/reducer/constants_writer.rs +++ b/zokrates_core/src/static_analysis/reducer/constants_writer.rs @@ -1,4 +1,4 @@ -// A folder to inline all constant definitions down to a single litteral and register them in the state for later use. +// A folder to inline all constant definitions down to a single literal and register them in the state for later use. use crate::static_analysis::reducer::{ constants_reader::ConstantsReader, reduce_function, ConstantDefinitions, Error, From 9efec3f5da3b297ab00293d7fdc4114029d3fd8c Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Sep 2021 20:15:42 +0200 Subject: [PATCH 60/78] update docs around field type --- changelogs/unreleased/1017-dark64 | 1 + zokrates_book/src/language/types.md | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 changelogs/unreleased/1017-dark64 diff --git a/changelogs/unreleased/1017-dark64 b/changelogs/unreleased/1017-dark64 new file mode 100644 index 000000000..59880b51d --- /dev/null +++ b/changelogs/unreleased/1017-dark64 @@ -0,0 +1 @@ +Add a book section on supported arithmetic operations for the `field` type \ No newline at end of file diff --git a/zokrates_book/src/language/types.md b/zokrates_book/src/language/types.md index 146de1296..af617258c 100644 --- a/zokrates_book/src/language/types.md +++ b/zokrates_book/src/language/types.md @@ -6,7 +6,7 @@ ZoKrates currently exposes two primitive types and two complex types: ### `field` -This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers. +This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](../toolbox/proving_schemes.md#curves) curve supported by Ethereum. @@ -16,7 +16,18 @@ While `field` values mostly behave like unsigned integers, one should keep in mi {{#include ../../../zokrates_cli/examples/book/field_overflow.zok}} ``` -Note that for field elements, the division operation multiplies the numerator with the denominator's inverse field element. The results coincide with integer divisions for cases with remainder 0, but differ otherwise. +#### Arithmetic operations + +| Symbol | Meaning | +| ------ | ------------------------------------------------ | +| `+` | Addition mod `p` | +| `-` | Subtraction mod `p` | +| `*` | Product mod `p` | +| `/` | Division (multiplication by the inverse) mod `p` | +| `**` | Power mod `p` | + +Note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers. +For field elements, the division operation multiplies the numerator with the denominator's inverse field element. The results coincide with integer divisions for cases with remainder 0, but differ otherwise. ### `bool` From d1a17e059ba5000b2f73cfa8d5029026f22390b7 Mon Sep 17 00:00:00 2001 From: dark64 Date: Fri, 24 Sep 2021 20:43:07 +0200 Subject: [PATCH 61/78] add changelog --- changelogs/unreleased/1015-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1015-dark64 diff --git a/changelogs/unreleased/1015-dark64 b/changelogs/unreleased/1015-dark64 new file mode 100644 index 000000000..3db3ee4fe --- /dev/null +++ b/changelogs/unreleased/1015-dark64 @@ -0,0 +1 @@ +Improve error message on unconstrained variable detection \ No newline at end of file From c502ecfec13ab7aaa77cc090f923b0e57e8e07cc Mon Sep 17 00:00:00 2001 From: Johnny Date: Sat, 25 Sep 2021 12:27:59 -0700 Subject: [PATCH 62/78] Fix typo --- zokrates_book/src/examples/sha256example.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_book/src/examples/sha256example.md b/zokrates_book/src/examples/sha256example.md index 0fd050695..e970d7dc2 100644 --- a/zokrates_book/src/examples/sha256example.md +++ b/zokrates_book/src/examples/sha256example.md @@ -87,7 +87,7 @@ Based on that Victor can run the setup phase and export a verifier smart contrac {{#include ../../../zokrates_cli/examples/book/sha256_tutorial/test.sh:18:19}} ``` -`setup` creates a `verifiation.key` file and a `proving.key` file. Victor gives the proving key to Peggy. +`setup` creates a `verification.key` file and a `proving.key` file. Victor gives the proving key to Peggy. `export-verifier` creates a `verifier.sol` contract that contains our verification key and a function `verifyTx`. Victor deploys this smart contract to the Ethereum network. From 73c5e1056952e2a357aa683831eddd151c2b9cf0 Mon Sep 17 00:00:00 2001 From: Johnny Date: Sat, 25 Sep 2021 12:32:20 -0700 Subject: [PATCH 63/78] Update introduction.md --- zokrates_book/src/introduction.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_book/src/introduction.md b/zokrates_book/src/introduction.md index 69bcc28df..ccc8dc23c 100644 --- a/zokrates_book/src/introduction.md +++ b/zokrates_book/src/introduction.md @@ -9,7 +9,7 @@ ZoKrates is a toolbox for zkSNARKs on Ethereum. It helps you use verifiable comp One particular family of ZKPs is described as zero-knowledge **S**uccinct **N**on-interactive **AR**guments of **K**nowledge, a.k.a. zkSNARKs. zkSNARKs are the most widely used zero-knowledge protocols, with the anonymous cryptocurrency Zcash and the smart-contract platform Ethereum among the notable early adopters. -For further details we refer the reader to some introductory material provided by the community: [[1]](https://z.cash/technology/zksnarks/),[[2]](https://medium.com/@VitalikButerin/zkSNARKs-under-the-hood-b33151a013f6), [[3]](https://blog.decentriq.ch/zk-SNARKs-primer-part-one/). +For further details we refer the reader to some introductory material provided by the community: [[1]](https://z.cash/technology/zksnarks/), [[2]](https://medium.com/@VitalikButerin/zkSNARKs-under-the-hood-b33151a013f6), [[3]](https://blog.decentriq.ch/zk-SNARKs-primer-part-one/). ## Motivation @@ -19,4 +19,4 @@ ZoKrates bridges this gap. It helps you create off-chain programs and link them ## License -ZoKrates is released under the GNU Lesser General Public License v3. \ No newline at end of file +ZoKrates is released under the GNU Lesser General Public License v3. From a514f400ae5eaeeeb7ae88e3d4f5656748e0e4a6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Sun, 26 Sep 2021 16:47:55 +0200 Subject: [PATCH 64/78] Set output timeout for zokrates_js_test job --- .circleci/config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 22d9d1b2a..b766ebcd3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -126,6 +126,7 @@ jobs: - checkout - run: name: Build + no_output_timeout: "30m" command: cd zokrates_js && npm run build:dev zokrates_js_test: docker: @@ -142,6 +143,7 @@ jobs: command: cargo clippy -- -D warnings - run: name: Run tests + no_output_timeout: "30m" command: npm run test cross_build: parameters: From 11d9dd939b31f664639f0a34b9e0b9ffe9f4ebd9 Mon Sep 17 00:00:00 2001 From: dark64 Date: Sun, 26 Sep 2021 16:52:51 +0200 Subject: [PATCH 65/78] fix integration tests --- zokrates_cli/tests/integration.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index 81ea370d4..2d47b8a03 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -410,14 +410,14 @@ mod integration { for p in glob("./examples/book/rng_tutorial/*").expect("Failed to read glob pattern") { let path = p.unwrap(); - std::fs::copy(path.clone(), tmp_base.join(path.file_name().unwrap())).unwrap(); + std::fs::hard_link(path.clone(), tmp_base.join(path.file_name().unwrap())).unwrap(); } let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); assert_cli::Assert::command(&[ "./test.sh", - env!("CARGO_BIN_EXE_zokrates"), + "../target/release/zokrates", stdlib.to_str().unwrap(), ]) .current_dir(tmp_base) @@ -433,14 +433,14 @@ mod integration { for p in glob("./examples/book/sha256_tutorial/*").expect("Failed to read glob pattern") { let path = p.unwrap(); - std::fs::copy(path.clone(), tmp_base.join(path.file_name().unwrap())).unwrap(); + std::fs::hard_link(path.clone(), tmp_base.join(path.file_name().unwrap())).unwrap(); } let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); assert_cli::Assert::command(&[ "./test.sh", - env!("CARGO_BIN_EXE_zokrates"), + "../target/release/zokrates", stdlib.to_str().unwrap(), ]) .current_dir(tmp_base) From 0d8d7549a21c97ceafdbe2b0d68039ab8e1b7438 Mon Sep 17 00:00:00 2001 From: dark64 Date: Sun, 26 Sep 2021 17:00:37 +0200 Subject: [PATCH 66/78] canonicalize paths --- zokrates_cli/tests/integration.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index 2d47b8a03..d2b02b8f7 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -414,10 +414,11 @@ mod integration { } let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); + let binary_path = std::fs::canonicalize("../target/release/zokrates").unwrap(); assert_cli::Assert::command(&[ "./test.sh", - "../target/release/zokrates", + binary_path.to_str().unwrap(), stdlib.to_str().unwrap(), ]) .current_dir(tmp_base) @@ -437,10 +438,11 @@ mod integration { } let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); + let binary_path = std::fs::canonicalize("../target/release/zokrates").unwrap(); assert_cli::Assert::command(&[ "./test.sh", - "../target/release/zokrates", + binary_path.to_str().unwrap(), stdlib.to_str().unwrap(), ]) .current_dir(tmp_base) From 2e2360c4c37ee3228bcb74b6390a788e7c4a31d4 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 27 Sep 2021 10:33:19 +0300 Subject: [PATCH 67/78] fix renaming --- zokrates_core/src/static_analysis/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 5a07f1437..d50c11fe8 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -26,7 +26,7 @@ use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_write_remover::VariableWriteRemover; use crate::compile::CompileConfig; use crate::ir::Prog; -use crate::static_analysis::constant_resolver::ConstantInliner; +use crate::static_analysis::constant_resolver::ConstantResolver; use crate::static_analysis::zir_propagation::ZirPropagator; use crate::typed_absy::{abi::Abi, TypedProgram}; use crate::zir::ZirProgram; @@ -95,7 +95,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> { // inline user-defined constants log::debug!("Static analyser: Inline constants"); - let r = ConstantInliner::inline(self); + let r = ConstantResolver::inline(self); log::trace!("\n{}", r); // isolate branches From 0fd1e0fb5ce5c903c79b70a6c572018398277ce5 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 27 Sep 2021 21:37:09 +0300 Subject: [PATCH 68/78] include g2 lib for gm17 --- zokrates_core/src/proof_system/scheme/gm17.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_core/src/proof_system/scheme/gm17.rs b/zokrates_core/src/proof_system/scheme/gm17.rs index 4bc389f2b..94c1aeaf2 100644 --- a/zokrates_core/src/proof_system/scheme/gm17.rs +++ b/zokrates_core/src/proof_system/scheme/gm17.rs @@ -49,7 +49,7 @@ impl Scheme for GM17 { impl SolidityCompatibleScheme for GM17 { fn export_solidity_verifier(vk: >::VerificationKey) -> String { let (mut template_text, solidity_pairing_lib) = - (String::from(CONTRACT_TEMPLATE), solidity_pairing_lib(false)); + (String::from(CONTRACT_TEMPLATE), solidity_pairing_lib(true)); // replace things in template let vk_regex = Regex::new(r#"(<%vk_[^i%]*%>)"#).unwrap(); From 04ff417d8756e0ac0a8efe8ab922214221a4c8cb Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 27 Sep 2021 22:32:12 +0200 Subject: [PATCH 69/78] update operators section --- changelogs/unreleased/1017-dark64 | 2 +- zokrates_book/.gitignore | 1 + zokrates_book/src/language/operators.md | 32 ++++++++++++------------- zokrates_book/src/language/types.md | 10 -------- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/changelogs/unreleased/1017-dark64 b/changelogs/unreleased/1017-dark64 index 59880b51d..efe6534ce 100644 --- a/changelogs/unreleased/1017-dark64 +++ b/changelogs/unreleased/1017-dark64 @@ -1 +1 @@ -Add a book section on supported arithmetic operations for the `field` type \ No newline at end of file +Make operators table more clear in the book \ No newline at end of file diff --git a/zokrates_book/.gitignore b/zokrates_book/.gitignore index 7585238ef..69745081c 100644 --- a/zokrates_book/.gitignore +++ b/zokrates_book/.gitignore @@ -1 +1,2 @@ book +mdbook \ No newline at end of file diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md index b330ea79c..7ff348508 100644 --- a/zokrates_book/src/language/operators.md +++ b/zokrates_book/src/language/operators.md @@ -1,22 +1,22 @@ ## Operators -The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same box group left to right. Operators are binary, unless the syntax is provided. +The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same box group have the same precedence. Operators are binary, unless the syntax is provided. -| Operator | Description | Remarks | -|---------------------------------|-------------------------------------------------------------------|---------| -| `**`
| Power | [^1] | -| `+x`
`-x`
`!x`
| Positive
Negative
Negation
| | -| `*`
`/`
`%`
| Multiplication
Division
Remainder
| | -| `+`
`-`
| Addition
Subtraction
| | -| `<<`
`>>`
| Left shift
Right shift
| [^2] | -| `&` | Bitwise AND | | -| | | Bitwise OR | | -| `^` | Bitwise XOR | | -| `>=`
`>`
`<=`
`<`| Greater or equal
Greater
Lower or equal
Lower
| [^3] | -| `!=`
`==`
| Not Equal
Equal
| | -| `&&` | Boolean AND | | -| || | Boolean OR | | -| `if c then x else y fi` | Conditional expression | | +| Operator | Description | Field | Unsigned integers | Bool | Associativity | Remarks | +|---------------------------------|-------------------------------------------------------------------|-----------------------------------|-------------------|-----------------------------------|---------------|---------| +| `**`
| Power | ✓ | × | × | Left | [^1] | +| `+x`
`-x`
`!x`
| Positive
Negative
Negation
| ✓ | ✓ | ×
×
✓ | Right | | +| `*`
`/`
`%`
| Multiplication
Division
Remainder
| ✓

× | ✓ | × | Left | | +| `+`
`-`
| Addition
Subtraction
| ✓ | ✓ | × | Left | | +| `<<`
`>>`
| Left shift
Right shift
| × | ✓ | × | Left | [^2] | +| `&` | Bitwise AND | × | ✓ | × | Left | | +| | | Bitwise OR | × | ✓ | × | Left | | +| `^` | Bitwise XOR | × | ✓ | × | Left | | +| `>=`
`>`
`<=`
`<`| Greater or equal
Greater
Lower or equal
Lower
| ✓ | ✓ | × | Left | [^3] | +| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | +| `&&` | Boolean AND | × | × | ✓ | Left | | +| || | Boolean OR | × | × | ✓ | Left | | +| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | [^1]: The exponent must be a compile-time constant of type `u32` diff --git a/zokrates_book/src/language/types.md b/zokrates_book/src/language/types.md index af617258c..b97c1913d 100644 --- a/zokrates_book/src/language/types.md +++ b/zokrates_book/src/language/types.md @@ -16,16 +16,6 @@ While `field` values mostly behave like unsigned integers, one should keep in mi {{#include ../../../zokrates_cli/examples/book/field_overflow.zok}} ``` -#### Arithmetic operations - -| Symbol | Meaning | -| ------ | ------------------------------------------------ | -| `+` | Addition mod `p` | -| `-` | Subtraction mod `p` | -| `*` | Product mod `p` | -| `/` | Division (multiplication by the inverse) mod `p` | -| `**` | Power mod `p` | - Note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers. For field elements, the division operation multiplies the numerator with the denominator's inverse field element. The results coincide with integer divisions for cases with remainder 0, but differ otherwise. From 38f0b4a83a673de390b9da0cdfd21d74b8337fa7 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 28 Sep 2021 15:12:42 +0300 Subject: [PATCH 70/78] add changelog --- changelogs/unreleased/1008-m1cm1c | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1008-m1cm1c diff --git a/changelogs/unreleased/1008-m1cm1c b/changelogs/unreleased/1008-m1cm1c new file mode 100644 index 000000000..47647fc4d --- /dev/null +++ b/changelogs/unreleased/1008-m1cm1c @@ -0,0 +1 @@ +Reduce the deployment cost of the g16 and pghr13 verifiers \ No newline at end of file From bd0373bf99adf78e260ed8e9a1bc92a71e6cfd8e Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 28 Sep 2021 15:57:49 +0300 Subject: [PATCH 71/78] add more details to the map, change header, format, remove spaces --- zokrates_book/src/language/operators.md | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md index 7ff348508..00a3ec2de 100644 --- a/zokrates_book/src/language/operators.md +++ b/zokrates_book/src/language/operators.md @@ -1,22 +1,22 @@ ## Operators -The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same box group have the same precedence. Operators are binary, unless the syntax is provided. +The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same cell have the same precedence. Operators are binary, unless the syntax is provided. -| Operator | Description | Field | Unsigned integers | Bool | Associativity | Remarks | -|---------------------------------|-------------------------------------------------------------------|-----------------------------------|-------------------|-----------------------------------|---------------|---------| -| `**`
| Power | ✓ | × | × | Left | [^1] | -| `+x`
`-x`
`!x`
| Positive
Negative
Negation
| ✓ | ✓ | ×
×
✓ | Right | | -| `*`
`/`
`%`
| Multiplication
Division
Remainder
| ✓

× | ✓ | × | Left | | -| `+`
`-`
| Addition
Subtraction
| ✓ | ✓ | × | Left | | -| `<<`
`>>`
| Left shift
Right shift
| × | ✓ | × | Left | [^2] | -| `&` | Bitwise AND | × | ✓ | × | Left | | -| | | Bitwise OR | × | ✓ | × | Left | | -| `^` | Bitwise XOR | × | ✓ | × | Left | | -| `>=`
`>`
`<=`
`<`| Greater or equal
Greater
Lower or equal
Lower
| ✓ | ✓ | × | Left | [^3] | -| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | -| `&&` | Boolean AND | × | × | ✓ | Left | | -| || | Boolean OR | × | × | ✓ | Left | | -| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | +| Operator | Description | `field` | `u8/u16` `u32/u64` | `bool` | Associativity | Remarks | +|-------------------------------|---------------------------------------------------------------|------------------------------|-------------------------------|-----------------------------|---------------|---------| +| `**`
| Power | ✓ |   |   | Left | [^1] | +| `+x`
`-x`
`!x`
| Positive
Negative
Negation
| ✓

  | ✓

  |  
 
✓ | Right | | +| `*`
`/`
`%`
| Multiplication
Division
Remainder
| ✓

  | ✓

✓ |  
 
  | Left | | +| `+`
`-`
| Addition
Subtraction
| ✓ | ✓ |   | Left | | +| `<<`
`>>`
| Left shift
Right shift
|   | ✓ |   | Left | [^2] | +| `&` | Bitwise AND |   | ✓ |   | Left | | +| | | Bitwise OR |   | ✓ |   | Left | | +| `^` | Bitwise XOR |   | ✓ |   | Left | | +| `>=`
`>`
`<=`
`<` | Greater or equal
Greater
Lower or equal
Lower
| ✓ | ✓ |   | Left | [^3] | +| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | +| `&&` | Boolean AND |   |   | ✓ | Left | | +| || | Boolean OR |   |   | ✓ | Left | | +| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | [^1]: The exponent must be a compile-time constant of type `u32` From 64c6a808ea65d68e61aea09151b5299b4fe87b50 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Sep 2021 15:51:27 +0200 Subject: [PATCH 72/78] fix table format --- zokrates_book/src/language/operators.md | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/zokrates_book/src/language/operators.md b/zokrates_book/src/language/operators.md index 00a3ec2de..b82d59b78 100644 --- a/zokrates_book/src/language/operators.md +++ b/zokrates_book/src/language/operators.md @@ -2,21 +2,21 @@ The following table lists the precedence and associativity of all operators. Operators are listed top to bottom, in ascending precedence. Operators in the same cell have the same precedence. Operators are binary, unless the syntax is provided. -| Operator | Description | `field` | `u8/u16` `u32/u64` | `bool` | Associativity | Remarks | -|-------------------------------|---------------------------------------------------------------|------------------------------|-------------------------------|-----------------------------|---------------|---------| -| `**`
| Power | ✓ |   |   | Left | [^1] | +| Operator | Description | `field` | `u8/u16` `u32/u64` | `bool` | Associativity | Remarks | +|----------------------------|------------------------------------------------------------|------------------------------|-------------------------------|-----------------------------|---------------|---------| +| `**`
| Power | ✓ |   |   | Left | [^1] | | `+x`
`-x`
`!x`
| Positive
Negative
Negation
| ✓

  | ✓

  |  
 
✓ | Right | | | `*`
`/`
`%`
| Multiplication
Division
Remainder
| ✓

  | ✓

✓ |  
 
  | Left | | -| `+`
`-`
| Addition
Subtraction
| ✓ | ✓ |   | Left | | -| `<<`
`>>`
| Left shift
Right shift
|   | ✓ |   | Left | [^2] | -| `&` | Bitwise AND |   | ✓ |   | Left | | -| | | Bitwise OR |   | ✓ |   | Left | | -| `^` | Bitwise XOR |   | ✓ |   | Left | | -| `>=`
`>`
`<=`
`<` | Greater or equal
Greater
Lower or equal
Lower
| ✓ | ✓ |   | Left | [^3] | -| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | -| `&&` | Boolean AND |   |   | ✓ | Left | | -| || | Boolean OR |   |   | ✓ | Left | | -| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | +| `+`
`-`
| Addition
Subtraction
| ✓ | ✓ |   | Left | | +| `<<`
`>>`
| Left shift
Right shift
|   | ✓ |   | Left | [^2] | +| `&` | Bitwise AND |   | ✓ |   | Left | | +| | | Bitwise OR |   | ✓ |   | Left | | +| `^` | Bitwise XOR |   | ✓ |   | Left | | +| `>=`
`>`
`<=`
`<` | Greater or equal
Greater
Lower or equal
Lower
| ✓ | ✓ |   | Left | [^3] | +| `!=`
`==`
| Not Equal
Equal
| ✓ | ✓ | ✓ | Left | | +| `&&` | Boolean AND |   |   | ✓ | Left | | +| || | Boolean OR |   |   | ✓ | Left | | +| `if c then x else y fi` | Conditional expression | ✓ | ✓ | ✓ | Right | | [^1]: The exponent must be a compile-time constant of type `u32` From 95b4c9db576f33bec9c0ae8e65e2eb61693afcb6 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Sep 2021 17:53:03 +0200 Subject: [PATCH 73/78] remove g2 addition lib in pghr13 --- zokrates_core/src/proof_system/scheme/pghr13.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/zokrates_core/src/proof_system/scheme/pghr13.rs b/zokrates_core/src/proof_system/scheme/pghr13.rs index 567a3441b..b639eacf6 100644 --- a/zokrates_core/src/proof_system/scheme/pghr13.rs +++ b/zokrates_core/src/proof_system/scheme/pghr13.rs @@ -1,5 +1,5 @@ use crate::proof_system::scheme::{NonUniversalScheme, Scheme}; -use crate::proof_system::solidity::{solidity_pairing_lib, SOLIDITY_G2_ADDITION_LIB}; +use crate::proof_system::solidity::solidity_pairing_lib; use crate::proof_system::{G1Affine, G2Affine, SolidityCompatibleField, SolidityCompatibleScheme}; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -136,10 +136,7 @@ impl SolidityCompatibleScheme for PGHR13 { let re = Regex::new(r"(?P0[xX][0-9a-fA-F]{64})").unwrap(); template_text = re.replace_all(&template_text, "uint256($v)").to_string(); - format!( - "{}{}{}", - SOLIDITY_G2_ADDITION_LIB, solidity_pairing_lib, template_text - ) + format!("{}{}", solidity_pairing_lib, template_text) } } From e30678f558f8c3166159415f60bb24d34cb87409 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 29 Sep 2021 16:00:46 +0200 Subject: [PATCH 74/78] add serde default attribute to compile config to accept partial objects --- changelogs/unreleased/1023-dark64 | 1 + zokrates_core/src/compile.rs | 2 + zokrates_js/Cargo.lock | 456 ++++++++++++++++++++++++++++-- zokrates_js/package-lock.json | 2 +- 4 files changed, 440 insertions(+), 21 deletions(-) create mode 100644 changelogs/unreleased/1023-dark64 diff --git a/changelogs/unreleased/1023-dark64 b/changelogs/unreleased/1023-dark64 new file mode 100644 index 000000000..ec724536e --- /dev/null +++ b/changelogs/unreleased/1023-dark64 @@ -0,0 +1 @@ +Add serde `default` attribute to compile config to accept partial objects \ No newline at end of file diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index fc721cfe4..bbb2fdb13 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -164,7 +164,9 @@ impl fmt::Display for CompileErrorInner { #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct CompileConfig { + #[serde(default)] pub allow_unconstrained_variables: bool, + #[serde(default)] pub isolate_branches: bool, } diff --git a/zokrates_js/Cargo.lock b/zokrates_js/Cargo.lock index 2566ac336..ccf4d6742 100644 --- a/zokrates_js/Cargo.lock +++ b/zokrates_js/Cargo.lock @@ -17,6 +17,17 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee2a4ec343196209d6594e19543ae87a39f96d5534d7174822a3ad825dd6ed7e" +[[package]] +name = "ahash" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +dependencies = [ + "getrandom 0.2.2", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.6.10" @@ -26,6 +37,219 @@ dependencies = [ "memchr", ] +[[package]] +name = "ark-bls12-377" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb89b97424403ec9cc22a1df0db748dd7396c9ba5fb5c71a6f0e10ae1d1a7449" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-r1cs-std", + "ark-std", +] + +[[package]] +name = "ark-bw6-761" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69ad8d74a8e083a59defc4a226a19759691337006d5c9397dbd793af9e406418" +dependencies = [ + "ark-bls12-377", + "ark-ec", + "ark-ff", + "ark-std", +] + +[[package]] +name = "ark-crypto-primitives" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74b83a7e125e5c611e4a997123effb2f02e3fbc66531dd77751d3016ee920741" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-nonnative-field", + "ark-r1cs-std", + "ark-relations", + "ark-snark", + "ark-std", + "blake2", + "derivative", + "digest 0.9.0", + "tracing", +] + +[[package]] +name = "ark-ec" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56006994f509d76fbce6f6ffe3108f7191b4f3754ecd00bbae7cac20ec05020" +dependencies = [ + "ark-ff", + "ark-serialize", + "ark-std", + "derivative", + "num-traits 0.2.12", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4d8802d40fce9212c5c09be08f75c4b3becc0c488e87f60fff787b01250ce33" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "derivative", + "num-traits 0.2.12", + "rustc_version", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c" +dependencies = [ + "quote 1.0.7", + "syn 1.0.34", +] + +[[package]] +name = "ark-ff-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b9c256a93a10ed9708c16a517d6dcfaba3d215c0d7fab44d29a9affefb5eeb8" +dependencies = [ + "num-bigint 0.4.2", + "num-traits 0.2.12", + "quote 1.0.7", + "syn 1.0.34", +] + +[[package]] +name = "ark-gm17" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c9085a6c89aa65178aa2718b2efb62fd7c4dc23fe25285204e30b56e4cbfcac" +dependencies = [ + "ark-crypto-primitives", + "ark-ec", + "ark-ff", + "ark-poly", + "ark-r1cs-std", + "ark-relations", + "ark-serialize", + "ark-std", + "derivative", + "tracing", +] + +[[package]] +name = "ark-nonnative-field" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17887af156e9911d1dba5b30d49256d508f82f6a4f765a6fad8b5c637b700353" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-r1cs-std", + "ark-relations", + "ark-std", + "derivative", + "num-bigint 0.4.2", + "num-integer", + "num-traits 0.2.12", + "tracing", +] + +[[package]] +name = "ark-poly" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6683d21645a2abb94034f6a14e708405e55d9597687952d54b2269922857a" +dependencies = [ + "ark-ff", + "ark-serialize", + "ark-std", + "derivative", + "hashbrown", +] + +[[package]] +name = "ark-r1cs-std" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a90fea2b84ae4443983d56540360ea004cab952292b7a6535798b6b9dcb7f41" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-relations", + "ark-std", + "derivative", + "num-bigint 0.4.2", + "num-traits 0.2.12", + "tracing", +] + +[[package]] +name = "ark-relations" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a42f124f8dfff2b0561143c0c7ea48d7f7dc8d2c4c1e87eca14a27430c653c0b" +dependencies = [ + "ark-ff", + "ark-std", + "tracing", +] + +[[package]] +name = "ark-serialize" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e9b59329dc9b92086b3dc619f31cef4a0c802f10829b575a3666d48a48387d" +dependencies = [ + "ark-serialize-derive", + "ark-std", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ac3d78c750b01f5df5b2e76d106ed31487a93b3868f14a7f0eb3a74f45e1d8a" +dependencies = [ + "proc-macro2 1.0.18", + "quote 1.0.7", + "syn 1.0.34", +] + +[[package]] +name = "ark-snark" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39da26432fe584b0010741299820145ec69180fe9ea18ddf96946932763624a1" +dependencies = [ + "ark-ff", + "ark-relations", + "ark-std", +] + +[[package]] +name = "ark-std" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5b856a29bea7b810858116a596beee3d20fc4c5aeb240e8e5a8bca4845a470" +dependencies = [ + "rand 0.7.3", + "rand_xorshift", +] + [[package]] name = "arrayvec" version = "0.4.12" @@ -88,6 +312,17 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f0dc55f2d8a1a85650ac47858bb001b4c0dd73d79e3c455a842925e68d29cd3" +[[package]] +name = "blake2" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" +dependencies = [ + "crypto-mac", + "digest 0.9.0", + "opaque-debug 0.3.0", +] + [[package]] name = "blake2-rfc_bellman_edition" version = "0.0.1" @@ -108,7 +343,7 @@ dependencies = [ "block-padding", "byte-tools", "byteorder", - "generic-array", + "generic-array 0.12.3", ] [[package]] @@ -184,6 +419,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypto-mac" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" +dependencies = [ + "generic-array 0.14.4", + "subtle", +] + [[package]] name = "csv" version = "1.1.3" @@ -206,13 +451,33 @@ dependencies = [ "memchr", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2 1.0.18", + "quote 1.0.7", + "syn 1.0.34", +] + [[package]] name = "digest" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" dependencies = [ - "generic-array", + "generic-array 0.12.3", +] + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array 0.14.4", ] [[package]] @@ -278,7 +543,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50c052fa6d4c2f12305ec364bfb8ef884836f3f61ea015b202372ff996d1ac4b" dependencies = [ - "num-bigint", + "num-bigint 0.2.6", "num-integer", "num-traits 0.2.12", "proc-macro2 1.0.18", @@ -391,6 +656,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "generic-array" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "501466ecc8a30d1d3b7fc9229b122b2ce8ed6e9d9223f1138d4babb253e51817" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.1.15" @@ -421,6 +696,15 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aaf91faf136cb47367fa430cd46e37a788775e7fa104f8b4bcb3861dc389b724" +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + [[package]] name = "hermit-abi" version = "0.1.15" @@ -531,11 +815,22 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74e768dff5fb39a41b3bcd30bb25cf989706c90d028d1ad71971987aa309d535" +dependencies = [ + "autocfg", + "num-integer", + "num-traits 0.2.12", +] + [[package]] name = "num-integer" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b" +checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" dependencies = [ "autocfg", "num-traits 0.2.12", @@ -588,9 +883,9 @@ checksum = "1ab52be62400ca80aa00285d25253d7f7c437b7375c4de678f5405d3afe82ca5" [[package]] name = "once_cell" -version = "1.4.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b631f7e854af39a1739f401cf34a8a013dfe09eac4fa4dba91e9768bd28168d" +checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" [[package]] name = "opaque-debug" @@ -598,6 +893,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "pairing_ce" version = "0.21.1" @@ -685,6 +986,12 @@ dependencies = [ "syn 1.0.34", ] +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + [[package]] name = "pin-utils" version = "0.1.0" @@ -802,6 +1109,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rand_xorshift" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77d416b86801d23dde1aa643023b775c3a462efc0ed96443add11546cdf1dca8" +dependencies = [ + "rand_core 0.5.1", +] + [[package]] name = "rdrand" version = "0.4.0" @@ -854,6 +1170,15 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" +[[package]] +name = "rustc_version" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" +dependencies = [ + "semver", +] + [[package]] name = "ryu" version = "1.0.5" @@ -869,7 +1194,7 @@ dependencies = [ "bellman_ce", "blake2-rfc_bellman_edition", "byteorder", - "digest", + "digest 0.8.1", "rand 0.4.6", "serde", "serde_derive", @@ -877,6 +1202,24 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + +[[package]] +name = "semver-parser" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" +dependencies = [ + "pest", +] + [[package]] name = "serde" version = "1.0.114" @@ -915,9 +1258,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df" dependencies = [ "block-buffer", - "digest", + "digest 0.8.1", "fake-simd", - "opaque-debug", + "opaque-debug 0.2.3", ] [[package]] @@ -927,9 +1270,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a256f46ea78a0c0d9ff00077504903ac881a1dafdc20da66545699e7776b3e69" dependencies = [ "block-buffer", - "digest", + "digest 0.8.1", "fake-simd", - "opaque-debug", + "opaque-debug 0.2.3", ] [[package]] @@ -947,6 +1290,12 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8" +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "0.15.44" @@ -999,6 +1348,35 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tracing" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84f96e095c0c82419687c20ddf5cb3eadb61f4e1405923c9dc8e53a1adacbda8" +dependencies = [ + "cfg-if 1.0.0", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98863d0dd09fa59a1b79c6750ad80dbda6b75f4e71c437a6a1a8cb91a8bcbd77" +dependencies = [ + "proc-macro2 1.0.18", + "quote 1.0.7", + "syn 1.0.34", +] + +[[package]] +name = "tracing-core" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46125608c26121c81b0c6d693eab5a420e416da7e43c426d2e8f7df8da8a3acf" + [[package]] name = "typed-arena" version = "1.7.0" @@ -1041,6 +1419,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ae116fef2b7fea257ed6440d3cfcff7f190865f170cdad00bb6465bf18ecba" +[[package]] +name = "version_check" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" + [[package]] name = "void" version = "1.0.2" @@ -1147,9 +1531,30 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "zeroize" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf68b08513768deaa790264a7fac27a58cbf2705cfcdc9448362229217d7e970" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdff2024a851a322b08f179173ae2ba620445aef1e838f0c196820eade4ae0c7" +dependencies = [ + "proc-macro2 1.0.18", + "quote 1.0.7", + "syn 1.0.34", + "synstructure", +] + [[package]] name = "zokrates_abi" -version = "0.1.4" +version = "0.1.5" dependencies = [ "serde", "serde_derive", @@ -1164,7 +1569,7 @@ version = "0.1.0" [[package]] name = "zokrates_core" -version = "0.6.4" +version = "0.6.6" dependencies = [ "bellman_ce", "bincode", @@ -1174,8 +1579,9 @@ dependencies = [ "getrandom 0.2.2", "hex", "lazy_static", + "log", "num", - "num-bigint", + "num-bigint 0.2.6", "pairing_ce", "rand 0.4.6", "rand 0.7.3", @@ -1192,10 +1598,20 @@ dependencies = [ [[package]] name = "zokrates_embed" -version = "0.1.3" +version = "0.1.4" dependencies = [ + "ark-bls12-377", + "ark-bw6-761", + "ark-crypto-primitives", + "ark-ec", + "ark-ff", + "ark-gm17", + "ark-r1cs-std", + "ark-relations", + "ark-std", "bellman_ce", "sapling-crypto_ce", + "zokrates_field", ] [[package]] @@ -1205,7 +1621,7 @@ dependencies = [ "bellman_ce", "bincode", "lazy_static", - "num-bigint", + "num-bigint 0.2.6", "num-integer", "num-traits 0.2.12", "serde", @@ -1216,7 +1632,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.0.33" +version = "1.0.35" dependencies = [ "console_error_panic_hook", "js-sys", @@ -1231,7 +1647,7 @@ dependencies = [ [[package]] name = "zokrates_parser" -version = "0.2.2" +version = "0.2.4" dependencies = [ "pest", "pest_derive", @@ -1239,7 +1655,7 @@ dependencies = [ [[package]] name = "zokrates_pest_ast" -version = "0.2.2" +version = "0.2.3" dependencies = [ "from-pest", "lazy_static", diff --git a/zokrates_js/package-lock.json b/zokrates_js/package-lock.json index 647b52688..1f5a38c26 100644 --- a/zokrates_js/package-lock.json +++ b/zokrates_js/package-lock.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.0.32", + "version": "1.0.35", "lockfileVersion": 1, "requires": true, "dependencies": { From 2fff71d49f3fe90c880c9ee265da3b5260d73d7e Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 29 Sep 2021 16:11:34 +0200 Subject: [PATCH 75/78] update wrapper.js --- zokrates_js/wrapper.js | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/zokrates_js/wrapper.js b/zokrates_js/wrapper.js index 636ca2e97..7ae941d9a 100644 --- a/zokrates_js/wrapper.js +++ b/zokrates_js/wrapper.js @@ -36,15 +36,12 @@ module.exports = (dep) => { return { compile: (source, options = {}) => { - const createConfig = (config) => ({ - allow_unconstrained_variables: false, - ...config - }); const { location = "main.zok", resolveCallback = () => null, config = {} } = options; + console.log(config); const callback = (currentLocation, importLocation) => { return resolveFromStdlib(currentLocation, importLocation) || resolveCallback(currentLocation, importLocation); }; - const { program, abi } = zokrates.compile(source, location, callback, createConfig(config)); + const { program, abi } = zokrates.compile(source, location, callback, config); return { program: new Uint8Array(program), abi From a0b5acf1e509133e9aa7e1658e4910139c712fbf Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 29 Sep 2021 16:12:08 +0200 Subject: [PATCH 76/78] remove console.log --- zokrates_js/wrapper.js | 1 - 1 file changed, 1 deletion(-) diff --git a/zokrates_js/wrapper.js b/zokrates_js/wrapper.js index 7ae941d9a..e85a56d51 100644 --- a/zokrates_js/wrapper.js +++ b/zokrates_js/wrapper.js @@ -37,7 +37,6 @@ module.exports = (dep) => { return { compile: (source, options = {}) => { const { location = "main.zok", resolveCallback = () => null, config = {} } = options; - console.log(config); const callback = (currentLocation, importLocation) => { return resolveFromStdlib(currentLocation, importLocation) || resolveCallback(currentLocation, importLocation); }; From 3cf08aa963f935643762e74aca9b2e93709a4af3 Mon Sep 17 00:00:00 2001 From: dark64 Date: Thu, 30 Sep 2021 15:46:51 +0200 Subject: [PATCH 77/78] remove changelog --- changelogs/unreleased/1023-dark64 | 1 - 1 file changed, 1 deletion(-) delete mode 100644 changelogs/unreleased/1023-dark64 diff --git a/changelogs/unreleased/1023-dark64 b/changelogs/unreleased/1023-dark64 deleted file mode 100644 index ec724536e..000000000 --- a/changelogs/unreleased/1023-dark64 +++ /dev/null @@ -1 +0,0 @@ -Add serde `default` attribute to compile config to accept partial objects \ No newline at end of file From 044456418a7db5de26980f5f5d5633a9d94a7ee2 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 4 Oct 2021 17:45:29 +0200 Subject: [PATCH 78/78] update changelog, bump versions --- CHANGELOG.md | 18 ++++++++++++++++++ Cargo.lock | 4 ++-- changelogs/unreleased/1008-m1cm1c | 1 - changelogs/unreleased/1013-schaeff | 1 - changelogs/unreleased/1015-dark64 | 1 - changelogs/unreleased/1017-dark64 | 1 - changelogs/unreleased/957-dark64 | 1 - changelogs/unreleased/974-schaeff | 1 - changelogs/unreleased/975-schaeff | 1 - changelogs/unreleased/977-dark64 | 1 - changelogs/unreleased/987-schaeff | 1 - changelogs/unreleased/992-dark64 | 1 - changelogs/unreleased/998-dark64 | 1 - zokrates_cli/Cargo.toml | 2 +- zokrates_core/Cargo.toml | 2 +- zokrates_js/Cargo.toml | 2 +- zokrates_js/package.json | 2 +- 17 files changed, 24 insertions(+), 17 deletions(-) delete mode 100644 changelogs/unreleased/1008-m1cm1c delete mode 100644 changelogs/unreleased/1013-schaeff delete mode 100644 changelogs/unreleased/1015-dark64 delete mode 100644 changelogs/unreleased/1017-dark64 delete mode 100644 changelogs/unreleased/957-dark64 delete mode 100644 changelogs/unreleased/974-schaeff delete mode 100644 changelogs/unreleased/975-schaeff delete mode 100644 changelogs/unreleased/977-dark64 delete mode 100644 changelogs/unreleased/987-schaeff delete mode 100644 changelogs/unreleased/992-dark64 delete mode 100644 changelogs/unreleased/998-dark64 diff --git a/CHANGELOG.md b/CHANGELOG.md index 418767fc0..94623389a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,24 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.7.7] - 2021-10-04 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.7.7 + +### Changes +- Reduce the deployment cost of the g16 and pghr13 verifiers (#1008, @m1cm1c) +- Make operators table more clear in the book (#1017, @dark64) +- Allow calls in constant definitions (#975, @schaeff) +- Handle out of bound accesses gracefully (#1013, @schaeff) +- Improve error message on unconstrained variable detection (#1015, @dark64) +- Apply propagation in ZIR (#957, @dark64) +- Fail on mistyped constants (#974, @schaeff) +- Graceful error handling on unconstrained variable detection (#977, @dark64) +- Fix incorrect propagation of spreads (#987, @schaeff) +- Add range semantics to docs (#992, @dark64) +- Fix invalid cast to `usize` which caused wrong values in 32-bit environments (#998, @dark64) + ## [0.7.6] - 2021-08-16 ### Release diff --git a/Cargo.lock b/Cargo.lock index 17e4488a3..84b07de18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2383,7 +2383,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.7.6" +version = "0.7.7" dependencies = [ "assert_cli", "bincode", @@ -2410,7 +2410,7 @@ version = "0.1.0" [[package]] name = "zokrates_core" -version = "0.6.6" +version = "0.6.7" dependencies = [ "ark-bls12-377", "ark-bn254", diff --git a/changelogs/unreleased/1008-m1cm1c b/changelogs/unreleased/1008-m1cm1c deleted file mode 100644 index 47647fc4d..000000000 --- a/changelogs/unreleased/1008-m1cm1c +++ /dev/null @@ -1 +0,0 @@ -Reduce the deployment cost of the g16 and pghr13 verifiers \ No newline at end of file diff --git a/changelogs/unreleased/1013-schaeff b/changelogs/unreleased/1013-schaeff deleted file mode 100644 index b47a580e2..000000000 --- a/changelogs/unreleased/1013-schaeff +++ /dev/null @@ -1 +0,0 @@ -Handle out of bound accesses gracefully \ No newline at end of file diff --git a/changelogs/unreleased/1015-dark64 b/changelogs/unreleased/1015-dark64 deleted file mode 100644 index 3db3ee4fe..000000000 --- a/changelogs/unreleased/1015-dark64 +++ /dev/null @@ -1 +0,0 @@ -Improve error message on unconstrained variable detection \ No newline at end of file diff --git a/changelogs/unreleased/1017-dark64 b/changelogs/unreleased/1017-dark64 deleted file mode 100644 index efe6534ce..000000000 --- a/changelogs/unreleased/1017-dark64 +++ /dev/null @@ -1 +0,0 @@ -Make operators table more clear in the book \ No newline at end of file diff --git a/changelogs/unreleased/957-dark64 b/changelogs/unreleased/957-dark64 deleted file mode 100644 index 2c4f7389d..000000000 --- a/changelogs/unreleased/957-dark64 +++ /dev/null @@ -1 +0,0 @@ -Apply propagation in ZIR \ No newline at end of file diff --git a/changelogs/unreleased/974-schaeff b/changelogs/unreleased/974-schaeff deleted file mode 100644 index 0f2622f47..000000000 --- a/changelogs/unreleased/974-schaeff +++ /dev/null @@ -1 +0,0 @@ -Fail on mistyped constants \ No newline at end of file diff --git a/changelogs/unreleased/975-schaeff b/changelogs/unreleased/975-schaeff deleted file mode 100644 index d0073f3cb..000000000 --- a/changelogs/unreleased/975-schaeff +++ /dev/null @@ -1 +0,0 @@ -Allow calls in constant definitions \ No newline at end of file diff --git a/changelogs/unreleased/977-dark64 b/changelogs/unreleased/977-dark64 deleted file mode 100644 index c29eed7dc..000000000 --- a/changelogs/unreleased/977-dark64 +++ /dev/null @@ -1 +0,0 @@ -Graceful error handling on unconstrained variable detection \ No newline at end of file diff --git a/changelogs/unreleased/987-schaeff b/changelogs/unreleased/987-schaeff deleted file mode 100644 index 8fe0d3fc6..000000000 --- a/changelogs/unreleased/987-schaeff +++ /dev/null @@ -1 +0,0 @@ -Fix incorrect propagation of spreads \ No newline at end of file diff --git a/changelogs/unreleased/992-dark64 b/changelogs/unreleased/992-dark64 deleted file mode 100644 index bdd50880b..000000000 --- a/changelogs/unreleased/992-dark64 +++ /dev/null @@ -1 +0,0 @@ -Add range semantics to docs \ No newline at end of file diff --git a/changelogs/unreleased/998-dark64 b/changelogs/unreleased/998-dark64 deleted file mode 100644 index d784d85c1..000000000 --- a/changelogs/unreleased/998-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix invalid cast to `usize` which caused wrong values in 32-bit environments \ No newline at end of file diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 177892a99..5c2b04cca 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.7.6" +version = "0.7.7" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index 41bc7f170..41d4a1056 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.6.6" +version = "0.6.7" edition = "2018" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index c2cb7e1c4..ef021eb03 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.0.35" +version = "1.0.36" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index de98ffe14..456a23e01 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -2,7 +2,7 @@ "name": "zokrates-js", "main": "index.js", "author": "Darko Macesic ", - "version": "1.0.35", + "version": "1.0.36", "keywords": [ "zokrates", "wasm-bindgen",