diff --git a/.github/workflows/js-format-check.yml b/.github/workflows/js-format-check.yml index f09fe5974..9c88c5cc9 100644 --- a/.github/workflows/js-format-check.yml +++ b/.github/workflows/js-format-check.yml @@ -6,6 +6,6 @@ jobs: steps: - uses: actions/checkout@v2 - name: Check format with prettier - uses: creyD/prettier_action@v4.2 + uses: creyD/prettier_action@v4.3 with: prettier_options: --check ./**/*.{js,ts,json} diff --git a/CHANGELOG.md b/CHANGELOG.md index 4188c1c34..9ea84c10a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.5] - 2023-03-28 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.5 + +### Changes +- Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) (#1283, @schaeff) +- Fix `radix-path` help message on `mpc init` subcommand (#1280, @dark64) +- Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair (#1277, @dark64) +- Show help when running `zokrates mpc` (#1275, @dark64) + ## [0.8.4] - 2023-01-31 ### Release diff --git a/Cargo.lock b/Cargo.lock index e21917722..e780f0273 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2906,7 +2906,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" dependencies = [ "cfg-if 0.1.10", "csv", @@ -2956,7 +2956,7 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" dependencies = [ "ark-bls12-377", "cfg-if 0.1.10", @@ -3004,7 +3004,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3064,7 +3064,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3147,7 +3147,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "num", @@ -3163,7 +3163,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" dependencies = [ "console_error_panic_hook", "getrandom", diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index e347abcdf..f93dc7d8d 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" edition = "2021" [features] diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index f4b81d8e1..0b834ef8a 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>( }) .collect(), )], - typed::TypedStatement::PushCallLog(..) => vec![], - typed::TypedStatement::PopCallLog => vec![], typed::TypedStatement::For(..) => unreachable!(), }; diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index c628e7283..539fe86cb 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>( let r = reduce_program(r).map_err(Error::from)?; log::trace!("\n{}", r); - log::debug!("Static analyser: Propagate"); - let r = Propagator::propagate(r)?; - log::trace!("\n{}", r); - log::debug!("Static analyser: Concretize structs"); let r = StructConcretizer::concretize(r); log::trace!("\n{}", r); diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7e5c0a17..7d77e86db 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -44,25 +44,16 @@ impl fmt::Display for Error { } } -#[derive(Debug)] -pub struct Propagator<'ast, 'a, T: Field> { +#[derive(Debug, Default)] +pub struct Propagator<'ast, T> { // constants keeps track of constant expressions // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is - constants: &'a mut Constants<'ast, T>, + constants: Constants<'ast, T>, } -impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { - pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self { - Propagator { constants } - } - +impl<'ast, T: Field> Propagator<'ast, T> { pub fn propagate(p: TypedProgram<'ast, T>) -> Result, Error> { - let mut constants = Constants::new(); - - Propagator { - constants: &mut constants, - } - .fold_program(p) + Propagator::default().fold_program(p) } // get a mutable reference to the constant corresponding to a given assignee if any, otherwise @@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { } } -impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { +impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { type Error = Error; fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result, Error> { @@ -629,8 +620,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { _ => Ok(vec![TypedStatement::Assertion(expr, err)]), } } - s @ TypedStatement::PushCallLog(..) => Ok(vec![s]), - s @ TypedStatement::PopCallLog => Ok(vec![s]), s => fold_statement(self, s), } } @@ -1502,7 +1491,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(5))) ); } @@ -1515,7 +1504,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); } @@ -1528,7 +1517,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(6))) ); } @@ -1541,7 +1530,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1554,15 +1543,14 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(8))) ); } #[test] fn left_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::LeftShift( @@ -1607,8 +1595,7 @@ mod tests { #[test] fn right_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::RightShift( @@ -1676,7 +1663,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); } @@ -1691,7 +1678,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1713,7 +1700,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1735,18 +1722,15 @@ mod tests { BooleanExpression::Not(box BooleanExpression::identifier("a".into())); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_default.clone()), + Propagator::default().fold_boolean_expression(e_default.clone()), Ok(e_default) ); } @@ -1776,23 +1760,19 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); } @@ -1800,38 +1780,42 @@ mod tests { #[test] fn bool_eq() { assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); } @@ -1933,33 +1917,27 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_true), + Propagator::default().fold_boolean_expression(e_non_canonical_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_false), + Propagator::default().fold_boolean_expression(e_non_canonical_false), Ok(BooleanExpression::Value(false)) ); } @@ -1977,13 +1955,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2001,13 +1977,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2025,13 +1999,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2049,13 +2021,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2065,67 +2035,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } @@ -2135,67 +2113,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index 4ee0d1359..f991f77fd 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -2,10 +2,11 @@ use crate::reducer::ConstantDefinitions; use zokrates_ast::typed::{ - folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, - DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression, - StructExpression, StructExpressionInner, StructType, TupleExpression, TupleExpressionInner, - TupleType, TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner, + folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType, + BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id, + Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType, + TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration, + UBitwidth, UExpression, UExpressionInner, }; use zokrates_field::Field; @@ -61,7 +62,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { FieldElementExpression::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. @@ -86,7 +91,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { BooleanExpression::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. @@ -112,7 +121,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { UExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. @@ -136,7 +149,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { ArrayExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. @@ -162,7 +179,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { TupleExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. @@ -188,7 +209,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { StructExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame: _, + }, version, }, .. diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index d4e03d3d4..50a6d25ad 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -5,9 +5,9 @@ use crate::reducer::{ }; use std::collections::{BTreeMap, HashSet}; use zokrates_ast::typed::{ - result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed, - TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, - TypedProgram, TypedSymbolDeclaration, UExpression, + result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol, + TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration, + UExpression, }; use zokrates_field::Field; @@ -118,11 +118,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { signature: DeclarationSignature::new().output(c.ty.clone()), }; - let mut inlined_wrapper = reduce_function( - wrapper, - ConcreteGenericsAssignment::default(), - &self.program, - )?; + let mut inlined_wrapper = reduce_function(wrapper, &self.program)?; if let TypedStatement::Return(expression) = inlined_wrapper.statements.pop().unwrap() diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 31f237e82..8d3727d3d 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -15,25 +15,24 @@ // ``` // // Becomes -// ``` -// # Call foo::<42> with a_0 := x -// n_0 = 42 -// a_1 = a_0 -// n_1 = n_0 -// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1 +// inputs: [a] +// arguments: [x] +// generics_bindings: [n = 42] +// statements: +// n = 42 +// a = a +// n = n +// return_expression: a // Notes: -// - The body of the function is in SSA form -// - The return value(s) are assigned to internal variables - -use crate::reducer::Output; -use crate::reducer::ShallowTransformer; -use crate::reducer::Versions; +// - The body of the function is *not* in SSA form use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; use zokrates_ast::typed::CoreIdentifier; -use zokrates_ast::typed::Identifier; + +use zokrates_ast::typed::TypedAssignee; +use zokrates_ast::typed::UBitwidth; use zokrates_ast::typed::{ ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr, Signature, Type, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, @@ -43,22 +42,12 @@ use zokrates_field::Field; pub enum InlineError<'ast, T> { Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>), - Flat( - FlatEmbed, - Vec, - Vec>, - Type<'ast, T>, - ), - NonConstant( - DeclarationFunctionKey<'ast, T>, - Vec>>, - Vec>, - Type<'ast, T>, - ), + Flat(FlatEmbed, Vec, Type<'ast, T>), + NonConstant, } fn get_canonical_function<'ast, T: Field>( - function_key: DeclarationFunctionKey<'ast, T>, + function_key: &DeclarationFunctionKey<'ast, T>, program: &TypedProgram<'ast, T>, ) -> TypedFunctionSymbolDeclaration<'ast, T> { let s = program @@ -66,30 +55,35 @@ fn get_canonical_function<'ast, T: Field>( .get(&function_key.module) .unwrap() .functions_iter() - .find(|d| d.key == function_key) + .find(|d| d.key == *function_key) .unwrap(); match &s.symbol { - TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), program), + TypedFunctionSymbol::There(key) => get_canonical_function(key, program), _ => s.clone(), } } -type InlineResult<'ast, T> = Result< - Output<(Vec>, TypedExpression<'ast, T>), Vec>>, - InlineError<'ast, T>, ->; +pub struct InlineValue<'ast, T> { + /// the pre-SSA input variables to assign the arguments to + pub input_variables: Vec>, + /// the pre-SSA statements for this call, including definition of the generic parameters + pub statements: Vec>, + /// the pre-SSA return value for this call + pub return_value: TypedExpression<'ast, T>, +} + +type InlineResult<'ast, T> = Result, InlineError<'ast, T>>; pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( - k: DeclarationFunctionKey<'ast, T>, - generics: Vec>>, - arguments: Vec>, - output: &E::Ty, + k: &DeclarationFunctionKey<'ast, T>, + generics: &[Option>], + arguments: &[TypedExpression<'ast, T>], + output_ty: &E::Ty, program: &TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, ) -> InlineResult<'ast, T> { use zokrates_ast::typed::Typed; - let output_type = output.clone().into_type(); + let output_type = output_ty.clone().into_type(); // we try to get concrete values for explicit generics let generics_values: Vec> = generics @@ -103,36 +97,23 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .transpose() }) .collect::>() - .map_err(|_| { - InlineError::NonConstant( - k.clone(), - generics.clone(), - arguments.clone(), - output_type.clone(), - ) - })?; + .map_err(|_| InlineError::NonConstant)?; // we infer a signature based on inputs and outputs - // this is where we could handle explicit annotations let inferred_signature = Signature::new() - .generics(generics.clone()) + .generics(generics.to_vec().clone()) .inputs(arguments.iter().map(|a| a.get_type()).collect()) .output(output_type.clone()); - // we try to get concrete values for the whole signature. if this fails we should propagate again + // we try to get concrete values for the whole signature let inferred_signature = match ConcreteSignature::try_from(inferred_signature) { Ok(s) => s, Err(_) => { - return Err(InlineError::NonConstant( - k, - generics, - arguments, - output_type, - )); + return Err(InlineError::NonConstant); } }; - let decl = get_canonical_function(k.clone(), program); + let decl = get_canonical_function(k, program); // get an assignment of generics for this call site let assignment: ConcreteGenericsAssignment<'ast> = k @@ -154,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( e, e.generics::(&assignment), - arguments.clone(), output_type, )), _ => unreachable!(), @@ -162,59 +142,38 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); - let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) { - Output::Complete(v) => (v, None), - Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)), - }; - - let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone()); - - let input_bindings: Vec> = ssa_f + let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| { + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::uint( + CoreIdentifier::from(identifier), + UBitwidth::B32, + )), + TypedExpression::from(UExpression::from(value)).into(), + ) + }); + + let input_variables: Vec> = f .arguments .into_iter() .zip(inferred_signature.inputs.clone()) .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) - .zip(arguments.clone()) - .map(|(v, a)| TypedStatement::definition(Variable::from(v).into(), a)) + .map(Variable::from) .collect(); - let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f - .statements - .into_iter() + let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings + .chain(f.statements) .partition(|s| !matches!(s, TypedStatement::Return(..))); assert_eq!(returns.len(), 1); - let return_expression = match returns.pop().unwrap() { + let return_value = match returns.pop().unwrap() { TypedStatement::Return(e) => e, _ => unreachable!(), }; - let v: ConcreteVariable<'ast> = ConcreteVariable::new( - Identifier::from(CoreIdentifier::Call(0)).version( - *versions - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ), - *inferred_signature.output.clone(), - false, - ); - - let expression = TypedExpression::from(Variable::from(v.clone())); - - let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression); - - let pop_log = TypedStatement::PopCallLog; - - let statements: Vec<_> = std::iter::once(call_log) - .chain(input_bindings) - .chain(statements) - .chain(std::iter::once(output_binding)) - .chain(std::iter::once(pop_log)) - .collect(); - - Ok(incomplete_data - .map(|d| Output::Incomplete((statements.clone(), expression.clone()), d)) - .unwrap_or_else(|| Output::Complete((statements, expression)))) + Ok(InlineValue { + input_variables, + statements, + return_value, + }) } diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index ea6feb536..826cd6663 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -3,40 +3,42 @@ // - free of function calls (except for low level calls) thanks to inlining // - free of for-loops thanks to unrolling -// The process happens in two steps -// 1. Shallow SSA for the `main` function -// We turn the `main` function into SSA form, but ignoring function calls and for loops -// 2. Unroll and inline -// We go through the shallow-SSA program and -// - unroll loops -// - inline function calls. This includes applying shallow-ssa on the target function +// The process happens in a greedy way, starting from the main function +// For each statement: +// * put it in ssa form +// * propagate it +// * inline it (calling this process recursively) +// * propagate again + +// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach mod constants_reader; mod constants_writer; mod inline; mod shallow_ssa; +use self::inline::InlineValue; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; use zokrates_ast::typed::result_folder::*; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::GGenericsAssignment; +use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; -use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; - +use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{ ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, - TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, - TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, + FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression, + TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, + TypedStatement, UExpression, UExpressionInner, }; +use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; use self::shallow_ssa::ShallowTransformer; -use crate::propagation::{Constants, Propagator}; +use crate::propagation; +use crate::propagation::Propagator; use std::fmt; @@ -46,25 +48,15 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20); pub type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; -// An SSA version map, giving access to the latest version number for each identifier -pub type Versions<'ast> = HashMap, usize>; - -// A container to represent whether more treatment must be applied to the function #[derive(Debug, PartialEq, Eq)] -pub enum Output { - Complete(U), - Incomplete(U, V), -} - -#[derive(Debug, Clone, PartialEq, Eq)] pub enum Error { Incompatible(String), GenericsInMain, - // TODO: give more details about what's blocking the progress - NoProgress, LoopTooLarge(u128), ConstantReduction(String, OwnedTypedModuleId), + NonConstant(String), Type(String), + Propagation(propagation::Error), } impl fmt::Display for Error { @@ -76,133 +68,36 @@ impl fmt::Display for Error { s ), Error::GenericsInMain => write!(f, "Cannot generate code for generic function"), - Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"), Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE), Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()), - Error::Type(message) => write!(f, "{}", message), + Error::NonConstant(s) => write!(f, "{}", s), + Error::Type(s) => write!(f, "{}", s), + Error::Propagation(e) => write!(f, "{}", e), } } } -#[derive(Debug, Default)] -struct Substitutions<'ast>(HashMap, HashMap>); - -impl<'ast> Substitutions<'ast> { - // create an equivalent substitution map where all paths - // are of length 1 - fn canonicalize(self) -> Self { - Substitutions( - self.0 - .into_iter() - .map(|(id, sub)| (id, Self::canonicalize_sub(sub))) - .collect(), - ) - } - - // canonicalize substitutions for a given id - fn canonicalize_sub(sub: HashMap) -> HashMap { - fn add_to_cache( - sub: &HashMap, - cache: HashMap, - k: usize, - ) -> HashMap { - match cache.contains_key(&k) { - // `k` is already in the cache, no changes to the cache - true => cache, - _ => match sub.get(&k) { - // `k` does not point to anything, no changes to the cache - None => cache, - // `k` points to some `v - Some(v) => { - // add `v` to the cache - let cache = add_to_cache(sub, cache, *v); - // `k` points to what `v` points to, or to `v` - let v = cache.get(v).cloned().unwrap_or(*v); - let mut cache = cache; - cache.insert(k, v); - cache - } - }, - } - } - - sub.keys() - .fold(HashMap::new(), |cache, k| add_to_cache(&sub, cache, *k)) - } -} - -struct Sub<'a, 'ast> { - substitutions: &'a Substitutions<'ast>, -} - -impl<'a, 'ast> Sub<'a, 'ast> { - fn new(substitutions: &'a Substitutions<'ast>) -> Self { - Self { substitutions } - } -} - -impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> { - fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> { - let version = self - .substitutions - .0 - .get(&id.id) - .map(|sub| sub.get(&id.version).cloned().unwrap_or(id.version)) - .unwrap_or(id.version); - id.version(version) - } -} - -fn register<'ast>( - substitutions: &mut Substitutions<'ast>, - substitute: &Versions<'ast>, - with: &Versions<'ast>, -) { - for (id, key, value) in substitute - .iter() - .filter_map(|(id, version)| with.get(id).map(|to| (id, version, to))) - .filter(|(_, key, value)| key != value) - { - let sub = substitutions.0.entry(id.clone()).or_default(); - - // redirect `k` to `v`, unless `v` is already redirected to `v0`, in which case we redirect to `v0` - - sub.insert(*key, *sub.get(value).unwrap_or(value)); +impl From for Error { + fn from(e: propagation::Error) -> Self { + Self::Propagation(e) } } #[derive(Debug)] struct Reducer<'ast, 'a, T> { - statement_buffer: Vec>, - for_loop_versions: Vec>, - for_loop_versions_after: Vec>, program: &'a TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, - substitutions: &'a mut Substitutions<'ast>, - complete: bool, + propagator: Propagator<'ast, T>, + ssa: ShallowTransformer<'ast>, + statement_buffer: Vec>, } impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { - fn new( - program: &'a TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, - substitutions: &'a mut Substitutions<'ast>, - for_loop_versions: Vec>, - ) -> Self { - // we reverse the vector as it's cheaper to `pop` than to take from - // the head - let mut for_loop_versions = for_loop_versions; - - for_loop_versions.reverse(); - + fn new(program: &'a TypedProgram<'ast, T>) -> Self { Reducer { + propagator: Propagator::default(), + ssa: ShallowTransformer::default(), statement_buffer: vec![], - for_loop_versions_after: vec![], - for_loop_versions, - substitutions, program, - versions, - complete: true, } } } @@ -210,6 +105,13 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { type Error = Error; + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast, T>, + ) -> Result, Self::Error> { + Ok(self.ssa.fold_parameter(p)) + } + fn fold_function_call_expression< E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, >( @@ -217,65 +119,98 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ty: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, Self::Error> { - let generics = e + // generics are already in ssa form + + let generics: Vec<_> = e .generics .into_iter() - .map(|g| g.map(|g| self.fold_uint_expression(g)).transpose()) + .map(|g| { + g.map(|g| { + let g = self.propagator.fold_uint_expression(g)?; + let g = self.fold_uint_expression(g)?; + + self.propagator + .fold_uint_expression(g) + .map_err(Self::Error::from) + }) + .transpose() + }) .collect::>()?; - let arguments = e + // arguments are already in ssa form + + let arguments: Vec<_> = e .arguments .into_iter() - .map(|e| self.fold_expression(e)) + .map(|e| { + let e = self.propagator.fold_expression(e)?; + let e = self.fold_expression(e)?; + + self.propagator + .fold_expression(e) + .map_err(Self::Error::from) + }) .collect::>()?; - let res = inline_call::<_, E>( - e.function_key.clone(), - generics, - arguments, - ty, - self.program, - self.versions, - ); + self.ssa.push_call_frame(); + + let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program); + + let res = match res { + Ok(InlineValue { + input_variables, + statements, + return_value, + }) => { + // the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs + let input_bindings: Vec<_> = input_variables + .into_iter() + .zip(arguments) + .map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a)) + .collect(); + + let input_bindings = input_bindings + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + self.statement_buffer.extend(input_bindings); + + let statements = statements + .into_iter() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); - match res { - Ok(Output::Complete((statements, expression))) => { - self.complete &= true; - self.statement_buffer.extend(statements); - Ok(FunctionCallOrExpression::Expression( - E::from(expression).into_inner(), - )) - } - Ok(Output::Incomplete((statements, expression), delta_for_loop_versions)) => { - self.complete = false; self.statement_buffer.extend(statements); - self.for_loop_versions_after.extend(delta_for_loop_versions); + + let return_value = self.ssa.fold_expression(return_value); + + let return_value = self.propagator.fold_expression(return_value)?; + + let return_value = self.fold_expression(return_value)?; + Ok(FunctionCallOrExpression::Expression( - E::from(expression.clone()).into_inner(), + E::from(return_value).into_inner(), )) } Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => { - self.complete = false; - - Ok(FunctionCallOrExpression::Expression(E::function_call( - key, generics, arguments, - ))) - } - Err(InlineError::Flat(embed, generics, arguments, output_type)) => { - let identifier = Identifier::from(CoreIdentifier::Call(0)).version( - *self - .versions - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ); + Err(InlineError::NonConstant) => Err(Error::NonConstant(format!( + "Generic parameters must be compile-time constants, found {}", + FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments) + ))), + Err(InlineError::Flat(embed, generics, output_type)) => { + let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let var = Variable::immutable(identifier.clone(), output_type); - let v = var.clone().into(); + + let v: TypedAssignee<'ast, T> = var.clone().into(); self.statement_buffer .push(TypedStatement::embed_call_definition( @@ -286,7 +221,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { identifier, ))) } - } + }; + + self.ssa.pop_call_frame(); + + res } fn fold_block_expression>( @@ -325,74 +264,70 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result>, Self::Error> { let res = match s { TypedStatement::For(v, from, to, statements) => { - let versions_before = self.for_loop_versions.pop().unwrap(); - - match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => { - let mut out_statements = vec![]; - - // get a fresh set of versions for all variables to use as a starting point inside the loop - self.versions.values_mut().for_each(|v| *v += 1); - - // add this set of versions to the substitution, pointing to the versions before the loop - register(self.substitutions, self.versions, &versions_before); - - // the versions after the loop are found by applying an offset of 1 to the versions before the loop - let versions_after = versions_before - .clone() - .into_iter() - .map(|(k, v)| (k, v + 1)) - .collect(); - - let mut transformer = ShallowTransformer::with_versions(self.versions); - - if to - from > MAX_FOR_LOOP_SIZE { - return Err(Error::LoopTooLarge(to.saturating_sub(*from))); - } - - for index in *from..*to { - let statements: Vec> = - std::iter::once(TypedStatement::definition( - v.clone().into(), - UExpression::from(index as u32).into(), - )) - .chain(statements.clone().into_iter()) - .flat_map(|s| transformer.fold_statement(s)) - .collect(); - - out_statements.extend(statements); - } - - let backups = transformer.for_loop_backups; - let blocked = transformer.blocked; - - // we know the final versions of the variables after full unrolling of the loop - // the versions after the loop need to point to these, so we add to the substitutions - register(self.substitutions, &versions_after, self.versions); - - // we may have found new for loops when unrolling this one, which means new backed up versions - // we insert these in our backup list and update our cursor - - self.for_loop_versions_after.extend(backups); + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; - // if the ssa transform got blocked, the reduction is not complete - self.complete &= !blocked; + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; - Ok(out_statements) - } - _ => { - let from = self.fold_uint_expression(from)?; - let to = self.fold_uint_expression(to)?; - self.complete = false; - self.for_loop_versions_after.push(versions_before); - Ok(vec![TypedStatement::For(v, from, to, statements)]) + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(from), UExpressionInner::Value(to)) + if to - from > MAX_FOR_LOOP_SIZE => + { + Err(Error::LoopTooLarge(to.saturating_sub(*from))) } - } + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from + ..*to) + .flat_map(|index| { + std::iter::once(TypedStatement::definition( + v.clone().into(), + UExpression::from(index as u32).into(), + )) + .chain(statements.clone()) + .map(|s| self.fold_statement(s)) + .collect::>() + }) + .collect::, _>>()? + .into_iter() + .flatten() + .collect::>()), + _ => Err(Error::NonConstant(format!( + "Expected loop bounds to be constant, found {}..{}", + from, to + ))), + }? + } + s => { + let statements = self.ssa.fold_statement(s); + + let statements = statements + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| fold_statement(self, s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + statements.collect() } - s => fold_statement(self, s), }; - res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + Ok(self.statement_buffer.drain(..).chain(res).collect()) } fn fold_array_expression_inner( @@ -402,18 +337,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { match e { ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = self.ssa.fold_array_expression(array); + let array = self.propagator.fold_array_expression(array)?; let array = self.fold_array_expression(array)?; + let array = self.propagator.fold_array_expression(array)?; + + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { Ok(ArrayExpressionInner::Slice(box array, box from, box to)) } - _ => { - self.complete = false; - Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - } + _ => Err(Error::NonConstant(format!( + "Slice bounds must be compile time constants, found {}", + ArrayExpressionInner::Slice(box array, box from, box to) + ))), } } _ => fold_array_expression_inner(self, array_ty, e), @@ -443,7 +389,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E match main_function.signature.generics.len() { 0 => { - let main_function = reduce_function(main_function, GGenericsAssignment::default(), &p)?; + let main_function = Reducer::new(&p).fold_function(main_function)?; Ok(TypedProgram { main: p.main.clone(), @@ -467,91 +413,11 @@ pub fn reduce_program(p: TypedProgram) -> Result, E fn reduce_function<'ast, T: Field>( f: TypedFunction<'ast, T>, - generics: ConcreteGenericsAssignment<'ast>, program: &TypedProgram<'ast, T>, ) -> Result, Error> { - let mut versions = Versions::default(); - - let mut constants = Constants::default(); - - let f = match ShallowTransformer::transform(f, &generics, &mut versions) { - Output::Complete(f) => Ok(f), - Output::Incomplete(new_f, new_for_loop_versions) => { - let mut for_loop_versions = new_for_loop_versions; - - let mut f = new_f; - - let mut substitutions = Substitutions::default(); - - let mut hash = None; - - loop { - let mut reducer = Reducer::new( - program, - &mut versions, - &mut substitutions, - for_loop_versions, - ); - - let new_f = TypedFunction { - statements: f - .statements - .into_iter() - .map(|s| reducer.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ..f - }; - - assert!(reducer.for_loop_versions.is_empty()); - - match reducer.complete { - true => { - substitutions = substitutions.canonicalize(); - - let new_f = Sub::new(&substitutions).fold_function(new_f); - - let new_f = Propagator::with_constants(&mut constants) - .fold_function(new_f) - .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - break Ok(new_f); - } - false => { - for_loop_versions = reducer.for_loop_versions_after; - - let new_f = Sub::new(&substitutions).fold_function(new_f); - - f = Propagator::with_constants(&mut constants) - .fold_function(new_f) - .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - let new_hash = Some(compute_hash(&f)); + assert!(f.signature.generics.is_empty()); - if new_hash == hash { - break Err(Error::NoProgress); - } else { - hash = new_hash - } - } - } - } - } - }?; - - Propagator::with_constants(&mut constants) - .fold_function(f) - .map_err(|e| Error::Incompatible(format!("{}", e))) -} - -fn compute_hash(f: &TypedFunction) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut s = DefaultHasher::new(); - f.hash(&mut s); - s.finish() + Reducer::new(program).fold_function(f) } #[cfg(test)] @@ -588,14 +454,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // a_1 = a_0; - // # PUSH CALL to foo - // a_3 := a_1; // input binding - // #RETURN_AT_INDEX_0_0 := a_3; - // # POP CALL - // a_2 = #RETURN_AT_INDEX_0_0; - // return a_2; + // def main(field a_f0_v0) -> field { + // a_f0_v1 = a_f0_v0; // redef + // a_f1_v0 = a_f0_v1; // input binding + // a_f0_v2 = a_f1_v0; // output binding + // return a_f0_v2; // } let foo: TypedFunction = TypedFunction { @@ -691,30 +554,13 @@ mod tests { Variable::field_element(Identifier::from("a").version(1)).into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo").signature( - DeclarationSignature::new() - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - ), - GGenericsAssignment::default(), - ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), + Variable::field_element(Identifier::from("a").in_frame(1)).into(), FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), - TypedStatement::definition( - Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0)) - .into(), - FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::field_element(Identifier::from("a").version(2)).into(), - FieldElementExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .into(), + FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(), ), TypedStatement::Return( FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), @@ -763,14 +609,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a_f0_v0) -> field { + // field[1] b_f0_v0 = [a_f0_v0]; + // a_f1_v0 = b_f0_v0; + // b_f0_v1 = a_f1_v0; + // return a_f0_v0 + b_f0_v1[0]; // } let foo_signature = DeclarationSignature::new() @@ -897,42 +740,19 @@ mod tests { .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -987,14 +807,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a) -> field { + // field[1] b = [a]; + // a_f1 = b; + // b_1 = a_f1; + // return a + b_1[0]; // } let foo_signature = DeclarationSignature::new() @@ -1125,47 +942,25 @@ mod tests { TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), + vec![FieldElementExpression::identifier(Identifier::from("a")).into()] + .into(), ) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -1391,33 +1186,11 @@ mod tests { let expected_main = TypedFunction { arguments: vec![], - statements: vec![ - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "bar") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 2)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PopCallLog, - TypedStatement::PopCallLog, - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) - .annotate(TupleType::new(vec![])) - .into(), - ), - ], + statements: vec![TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + )], signature: DeclarationSignature::new(), }; diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index a071a0446..aaec4fd56 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -1,7 +1,6 @@ -// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can -// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility. -// Function calls are also left unvisited -// Saving the indices is not required for function calls, as they cannot modify their environment +// The SSA transformation +// * introduces new versions if and only if we are assigning to an identifier +// * does not visit the statements of loops // Example: // def main(field a) -> field { @@ -19,178 +18,167 @@ // u32 n_0 = 42; // a_1 = a_0 + 1; // field b_0 = foo(a_1); // we keep the function call as is -// # versions: {n: 0, a: 1, b: 0} // for u32 i_0 in 0..n_0 { // // we keep the loop body as is // } // return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop // } +use std::collections::HashMap; + use zokrates_ast::typed::folder::*; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::Type; + use zokrates_ast::typed::*; use zokrates_field::Field; -use super::{Output, Versions}; - -pub struct ShallowTransformer<'ast, 'a> { - // version index for any variable name - pub versions: &'a mut Versions<'ast>, - // A backup of the versions before each for-loop - pub for_loop_backups: Vec>, - // whether all statements could be unrolled so far. Loops with variable bounds cannot. - pub blocked: bool, +// An SSA version map, giving access to the latest version number for each identifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Versions<'ast> { + map: HashMap, usize>>, } -impl<'ast, 'a> ShallowTransformer<'ast, 'a> { - pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { - ShallowTransformer { - versions, - for_loop_backups: Vec::default(), - blocked: false, +impl<'ast> Default for Versions<'ast> { + fn default() -> Self { + // create a call frame at index 0 + Self { + map: vec![(0, Default::default())].into_iter().collect(), } } +} - // increase all versions by 1 and return the old versions - fn create_version_gap(&mut self) -> Versions<'ast> { - let ret = self.versions.clone(); - self.versions.values_mut().for_each(|v| *v += 1); - ret - } +#[derive(Debug, Default)] +pub struct ShallowTransformer<'ast> { + // version index for any variable name + pub versions: Versions<'ast>, + pub frames: Vec, + pub latest_frame: usize, +} - fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let version = *self - .versions +impl<'ast> ShallowTransformer<'ast> { + pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { + let frame = self.frame(); + + let frame_versions = self.versions.map.entry(frame).or_default(); + + let version = frame_versions .entry(c_id.clone()) .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0); // otherwise, we start from this version + .or_default(); // otherwise, we start from this version - Identifier::from(c_id).version(version) + Identifier::from(c_id.in_frame(frame)).version(*version) } fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { assert_eq!(v.id.version, 0); Variable { - id: self.issue_next_identifier(v.id.id), + id: self.issue_next_identifier(v.id.id.id), ..v } } - pub fn transform( - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - versions: &'a mut Versions<'ast>, - ) -> Output, Vec>> { - let mut unroller = ShallowTransformer::with_versions(versions); + fn frame(&self) -> usize { + *self.frames.last().unwrap_or(&0) + } - let f = unroller.fold_function(f, generics); + pub fn push_call_frame(&mut self) { + self.latest_frame += 1; + self.frames.push(self.latest_frame); + self.versions + .map + .insert(self.latest_frame, Default::default()); + } - match unroller.blocked { - false => Output::Complete(f), - true => Output::Incomplete(f, unroller.for_loop_backups), - } + pub fn pop_call_frame(&mut self) { + let frame = self.frames.pop().unwrap(); + self.versions.map.remove(&frame); } - fn fold_function( + // fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers, + // but this should not be applied recursively to complex assignees + fn fold_assignee_no_ssa_increase( &mut self, - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - ) -> TypedFunction<'ast, T> { - let mut f = f; - - f.statements = generics - .0 - .clone() - .into_iter() - .map(|(g, v)| { - TypedStatement::definition( - Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false) - .into(), - UExpression::from(v as u32).into(), - ) - }) - .chain(f.statements) - .collect(); + a: TypedAssignee<'ast, T>, + ) -> TypedAssignee<'ast, T> { + match a { + TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), + TypedAssignee::Select(box a, box index) => TypedAssignee::Select( + box self.fold_assignee_no_ssa_increase(a), + box self.fold_uint_expression(index), + ), + TypedAssignee::Member(box s, m) => { + TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m) + } + TypedAssignee::Element(box s, index) => { + TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index) + } + } + } +} - for arg in &f.arguments { - let _ = self.issue_next_identifier(arg.id.id.id.clone()); +impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { + fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { + for g in &f.signature.generics { + let generic_parameter = match g.as_ref().unwrap() { + DeclarationConstant::Generic(g) => g, + _ => unreachable!(), + }; + let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone())); } fold_function(self, f) } - fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast, T>, + ) -> DeclarationParameter<'ast, T> { + DeclarationParameter { + id: DeclarationVariable { + id: self.issue_next_identifier(p.id.id.id.id), + ..p.id + }, + ..p + } + } + + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { match a { + // create a new version for assignments to identifiers TypedAssignee::Identifier(v) => { let v = self.issue_next_ssa_variable(v); TypedAssignee::Identifier(self.fold_variable(v)) } - a => fold_assignee(self, a), + // otherwise, simply replace by the current version + a => self.fold_assignee_no_ssa_increase(a), } } -} -impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { - fn fold_assembly_statement( - &mut self, - s: TypedAssemblyStatement<'ast, T>, - ) -> Vec> { - match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedAssemblyStatement::Assignment(a, e)] - } - s => fold_assembly_statement(self, s), - } - } fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { match s { - TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedStatement::definition(a, e)] - } - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let embed_call = self.fold_embed_call(embed_call); - let assignee = self.fold_assignee(assignee); - vec![TypedStatement::embed_call_definition(assignee, embed_call)] - } + // only fold bounds of for loop statements TypedStatement::For(v, from, to, stats) => { let from = self.fold_uint_expression(from); let to = self.fold_uint_expression(to); - self.blocked = true; - let versions_before_loop = self.create_version_gap(); - self.for_loop_backups.push(versions_before_loop); vec![TypedStatement::For(v, from, to, stats)] } s => fold_statement(self, s), } } + // retrieve the latest version fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - let res = Identifier { - version: *self.versions.get(&(n.id)).unwrap_or(&0), - ..n - }; - res - } - - fn fold_function_call_expression< - E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, - >( - &mut self, - ty: &E::Ty, - c: FunctionCallExpression<'ast, T, E>, - ) -> FunctionCallOrExpression<'ast, T, E> { - if !c.function_key.id.starts_with('_') { - self.blocked = true; - } - - fold_function_call_expression(self, ty, c) + let version = self + .versions + .map + .get(&self.frame()) + .unwrap() + .get(&n.id.id) + .cloned() + .unwrap_or(0); + + n.in_frame(self.frame()).version(version) } } @@ -203,36 +191,57 @@ mod tests { use super::*; #[test] - fn detect_non_constant_bound() { - let loops: Vec> = vec![TypedStatement::For( - Variable::new("i", Type::Uint(UBitwidth::B32), false), - UExpression::identifier("i".into()).annotate(UBitwidth::B32), - 2u32.into(), - vec![], - )]; + fn ignore_loop_content() { + // field foo = 0 + // u32 i = 4; + // for u32 i in i..2 { + // foo = 5; + // } - let statements = loops; + // should be left unchanged, as we do not visit the loop content nor the index variable let f = TypedFunction { arguments: vec![], - signature: DeclarationSignature::new(), - statements, + statements: vec![ + TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))), + FieldElementExpression::Number(Bn128Field::from(4)).into(), + ), + TypedStatement::definition( + TypedAssignee::Identifier(Variable::uint( + Identifier::from("i"), + UBitwidth::B32, + )), + UExpression::from(0u32).into(), + ), + TypedStatement::For( + Variable::new("i", Type::Uint(UBitwidth::B32), false), + UExpression::identifier("i".into()).annotate(UBitwidth::B32), + 2u32.into(), + vec![TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from( + "foo", + ))), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + )], + ), + TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + ), + ], + signature: DeclarationSignature::default(), }; - match ShallowTransformer::transform( - f, - &ConcreteGenericsAssignment::default(), - &mut Versions::default(), - ) { - Output::Incomplete(..) => {} - _ => unreachable!(), - }; + let mut ssa = ShallowTransformer::default(); + + assert_eq!(ssa.fold_function(f.clone()), f); } #[test] fn definition() { - // field a - // a = 5 + // field a = 5 // a = 6 // a @@ -241,9 +250,7 @@ mod tests { // a_1 = 6 // a_1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -283,17 +290,14 @@ mod tests { #[test] fn incremental_definition() { - // field a - // a = 5 + // field a = 5 // a = a + 1 // should be turned into // a_0 = 5 // a_1 = a_0 + 1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -342,9 +346,7 @@ mod tests { // a_0 = 2 // a_1 = foo(a_0) - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -403,9 +405,7 @@ mod tests { // a_0 = [1, 1] // a_0[1] = 2 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), @@ -460,9 +460,7 @@ mod tests { // a_0 = [[0, 1], [2, 3]] // a_0 = [4, 5] - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); @@ -557,10 +555,10 @@ mod tests { mod for_loop { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; + #[test] fn treat_loop() { - // def main(field a) -> field { + // def main(field a) -> field { // u32 n = 42; // n = n; // a = a; @@ -575,24 +573,21 @@ mod tests { // return a; // } - // When called with K := 1, expected: + // expected: // def main(field a_0) -> field { - // u32 K = 1; // u32 n_0 = 42; // n_1 = n_0; // a_1 = a_0; - // # versions: {n: 1, a: 1, K: 0} // for u32 i_0 in n_1..n_1*n_1 { // a_0 = a_0; // } - // a_3 = a_2; - // # versions: {n: 2, a: 3, K: 1} - // for u32 i_0 in n_2..n_2*n_2 { + // a_2 = a_1; + // for u32 i_0 in n_1..n_1*n_1 { // a_0 = a_0; // } - // a_5 = a_4; - // return a_5; - // } # versions: {n: 3, a: 5, K: 2} + // a_3 = a_2; + // return a_3; + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -642,32 +637,15 @@ mod tests { TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); + let mut ssa = ShallowTransformer::default(); let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -696,16 +674,16 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), + Variable::field_element(Identifier::from("a").version(2)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), TypedStatement::For( Variable::uint("i", UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32) - * UExpression::identifier(Identifier::from("n").version(2)) + * UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), vec![TypedStatement::definition( Variable::field_element("a").into(), @@ -713,47 +691,35 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(5)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(4)).into(), + Variable::field_element(Identifier::from("a").version(3)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), TypedStatement::Return( - FieldElementExpression::identifier(Identifier::from("a").version(5)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; - assert_eq!( - versions, - vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)] - .into_iter() - .collect::() - ); + let res = ssa.fold_function(f); - let expected = Output::Incomplete( - expected, - vec![ - vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)] - .into_iter() - .collect::(), - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)] - .into_iter() - .collect::(), - ], + assert_eq!( + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect() + )] + .into_iter() + .collect() ); - assert_eq!(ssa, expected); + assert_eq!(res, expected); } } mod shadowing { - use zokrates_ast::typed::types::GGenericsAssignment; - use super::*; #[test] @@ -764,11 +730,11 @@ mod tests { // return; // } - // should become + // should become (only the field variable is affected as shadowing is taken care of in semantics already) - // def main(field a_0) { - // field a_1 = 42; - // bool a_2 = true; + // def main(field a_s0_v0) { + // field a_s0_v1 = 42; + // bool a_s1_v0 = true // return; // } @@ -780,7 +746,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean("a").into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -789,9 +759,7 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; let expected: TypedFunction = TypedFunction { @@ -802,7 +770,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean(Identifier::from("a").version(2)).into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -811,121 +783,17 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), - }; - - let mut versions = Versions::default(); - - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); - - assert_eq!(ssa, Output::Complete(expected)); - } - - #[test] - fn next_scope() { - // def main(field a) { - // for u32 i in 0..1 { - // a = a + 1 - // field a = 42 - // } - // return a - // } - - // should become - - // def main(field a_0) { - // # versions: {a: 0} - // for u32 i in 0..1 { - // a_0 = a_0 - // field a_0 = 42 - // } - // return a_1 - // } - - let f: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier("a".into()).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - "a".into(), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - }; - - let expected: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier(Identifier::from("a")).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - Identifier::from("a").version(1), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; - let mut versions = Versions::default(); - - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); + let ssa = ShallowTransformer::default().fold_function(f); - assert_eq!( - ssa, - Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()]) - ); + assert_eq!(ssa, expected); } } mod function_call { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; // test that function calls are left in #[test] fn treat_calls() { @@ -939,17 +807,12 @@ mod tests { // return a; // } - // When called with K := 1, expected: // def main(field a_0) -> field { - // K = 1; - // u32 n_0 = 42; - // n_1 = n_0; // a_1 = a_0; - // a_2 = foo::(a_1); - // n_2 = n_1; - // a_3 = a_2 * foo::(a_2); + // a_2 = foo::<42>(a_1); + // a_3 = a_2 * foo::<42>(a_2); // return a_3; - // } # versions: {n: 2, a: 3} + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -1007,25 +870,9 @@ mod tests { .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); - let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -1089,14 +936,23 @@ mod tests { .output(DeclarationType::FieldElement), }; + let mut ssa = ShallowTransformer::default(); + + let res = ssa.fold_function(f); + assert_eq!( - versions, - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] - .into_iter() - .collect::() + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] + .into_iter() + .collect() + )] + .into_iter() + .collect() ); - assert_eq!(ssa, Output::Incomplete(expected, vec![],)); + assert_eq!(res, expected); } } } diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324b..60eb498c6 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" edition = "2021" [features] diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index d3e87fcd0..722dcbf87 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -4,6 +4,8 @@ use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; +use super::identifier::FrameIdentifier; + pub trait Fold<'ast, T: Field>: Sized { fn fold>(self, f: &mut F) -> Self; } @@ -128,11 +130,12 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - let id = match n.id { - CoreIdentifier::Constant(c) => { - CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)) - } - id => id, + let id = match n.id.id.clone() { + CoreIdentifier::Constant(c) => FrameIdentifier { + id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)), + frame: 0, + }, + _ => n.id, }; Identifier { id, ..n } @@ -528,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a), - f.fold_expression(e), - )] + let e = f.fold_expression(e); + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -549,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)), - TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e)) + TypedStatement::Definition(a, rhs) => { + let rhs = f.fold_definition_rhs(rhs); + TypedStatement::Definition(f.fold_assignee(a), rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e), error) @@ -573,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_assembly_statement(s)) .collect(), ), - s => s, }; vec![res] } diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index 772eb2bf2..d23a9b3b7 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -24,18 +24,49 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { } } -impl<'ast> From> for CoreIdentifier<'ast> { - fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> { - CoreIdentifier::Constant(s) +impl<'ast> FrameIdentifier<'ast> { + pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { + FrameIdentifier { frame, ..self } } } +impl<'ast> Identifier<'ast> { + pub fn in_frame(self, frame: usize) -> Identifier<'ast> { + Identifier { + id: self.id.in_frame(frame), + ..self + } + } +} + +impl<'ast> CoreIdentifier<'ast> { + pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { + FrameIdentifier { id: self, frame } + } +} + +impl<'ast> From> for FrameIdentifier<'ast> { + fn from(s: CanonicalConstantIdentifier<'ast>) -> FrameIdentifier<'ast> { + FrameIdentifier::from(CoreIdentifier::Constant(s)) + } +} + +/// A identifier for a variable in a given call frame +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct FrameIdentifier<'ast> { + /// the id of the variable + #[serde(borrow)] + pub id: CoreIdentifier<'ast>, + /// the frame of the variable + pub frame: usize, +} + /// A identifier for a variable #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Identifier<'ast> { /// the id of the variable #[serde(borrow)] - pub id: CoreIdentifier<'ast>, + pub id: FrameIdentifier<'ast>, /// the version of the variable, used after SSA transformation pub version: usize, } @@ -58,7 +89,7 @@ impl<'ast> fmt::Display for ShadowedIdentifier<'ast> { if self.shadow == 0 { write!(f, "{}", self.id) } else { - write!(f, "{}_{}", self.id, self.shadow) + write!(f, "{}_s{}", self.id, self.shadow) } } } @@ -68,20 +99,45 @@ impl<'ast> fmt::Display for Identifier<'ast> { if self.version == 0 { write!(f, "{}", self.id) } else { - write!(f, "{}_{}", self.id, self.version) + write!(f, "{}_v{}", self.id, self.version) + } + } +} + +impl<'ast> fmt::Display for FrameIdentifier<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.frame == 0 { + write!(f, "{}", self.id) + } else { + write!(f, "{}_f{}", self.id, self.frame) } } } impl<'ast> From> for Identifier<'ast> { fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> { - Identifier::from(CoreIdentifier::Constant(id)) + Identifier::from(FrameIdentifier::from(CoreIdentifier::Constant(id))) + } +} + +impl<'ast> From> for Identifier<'ast> { + fn from(id: FrameIdentifier<'ast>) -> Identifier<'ast> { + Identifier { id, version: 0 } } } impl<'ast> From> for Identifier<'ast> { fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> { - Identifier { id, version: 0 } + Identifier { + id: FrameIdentifier::from(id), + version: 0, + } + } +} + +impl<'ast> From> for FrameIdentifier<'ast> { + fn from(id: CoreIdentifier<'ast>) -> FrameIdentifier<'ast> { + FrameIdentifier { id, frame: 0 } } } @@ -107,6 +163,6 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { - Identifier::from(CoreIdentifier::from(id)) + Identifier::from(FrameIdentifier::from(CoreIdentifier::from(id))) } } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 83ada241e..bd000d12a 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,7 +27,7 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; -use crate::typed::types::{ConcreteGenericsAssignment, IntoType}; +use crate::typed::types::IntoType; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; @@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { writeln!(f)?; - let mut tab = 0; - for s in &self.statements { - if let TypedStatement::PopCallLog = s { - tab -= 1; - }; - - s.fmt_indented(f, 1 + tab)?; - writeln!(f)?; - - if let TypedStatement::PushCallLog(..) = s { - tab += 1; - }; + writeln!(f, "{}", s)?; } writeln!(f, "}}")?; @@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> { Vec>, ), Log(FormatString, Vec>), - // Aux - PushCallLog( - DeclarationFunctionKey<'ast, T>, - ConcreteGenericsAssignment<'ast>, - ), - PopCallLog, Assembly(Vec>), } @@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> { } } -impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { - fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { - match self { - TypedStatement::For(variable, from, to, statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "for {} in {}..{} {{", variable, from, to)?; - for s in statements { - s.fmt_indented(f, depth + 1)?; - writeln!(f)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - TypedStatement::Assembly(statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "asm {{")?; - for s in statements { - writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - s => write!(f, "{}{}", "\t".repeat(depth), s), - } - } -} - impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { .collect::>() .join(", ") ), - TypedStatement::PushCallLog(ref key, ref generics) => write!( - f, - "// PUSH CALL TO {}/{}::<{}>", - key.module.display(), - key.id, - generics, - ), - TypedStatement::PopCallLog => write!(f, "// POP CALL",), TypedStatement::Assembly(ref statements) => { writeln!(f, "asm {{")?; for s in statements { diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index e4146c504..8ed911314 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -4,6 +4,8 @@ use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; +use super::identifier::FrameIdentifier; + pub trait ResultFold<'ast, T: Field>: Sized { fn fold>(self, f: &mut F) -> Result; } @@ -156,11 +158,12 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { - let id = match n.id { - CoreIdentifier::Constant(c) => { - CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?) - } - id => id, + let id = match n.id.id.clone() { + CoreIdentifier::Constant(c) => FrameIdentifier { + id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?), + frame: 0, + }, + _ => n.id, }; Ok(Identifier { id, ..n }) @@ -529,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result>, F::Error> { Ok(match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a)?, - f.fold_expression(e)?, - )] + let e = f.fold_expression(e)?; + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -551,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?), TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?) + let rhs = f.fold_definition_rhs(e)?; + TypedStatement::Definition(f.fold_assignee(a)?, rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e)?, error) @@ -583,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( .flatten() .collect(), ), - s => s, }; Ok(vec![res]) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index a453fef3f..f2bf23d5b 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -241,13 +241,14 @@ impl<'ast, T: Field> From> for UExpression<'ast, T> fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { DeclarationConstant::Generic(g) => { - UExpression::identifier(CoreIdentifier::from(g).into()).annotate(UBitwidth::B32) + UExpression::identifier(Identifier::from(CoreIdentifier::from(g))) + .annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { - UExpression::identifier(CoreIdentifier::from(v).into()).annotate(UBitwidth::B32) + UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32) } DeclarationConstant::Expression(e) => e.try_into().unwrap(), } @@ -1144,8 +1145,7 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq>( impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: CanonicalConstantIdentifier<'ast>) -> Self { - UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - .annotate(UBitwidth::B32) + UExpression::identifier(Identifier::from(FrameIdentifier::from(c))).annotate(UBitwidth::B32) } } @@ -1230,6 +1230,7 @@ pub use self::signature::{ try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature, }; +use super::identifier::FrameIdentifier; use super::{Id, ShadowedIdentifier}; pub mod signature { diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 3fa5b5feb..ccf98794d 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index 795e81ed0..d6b596f35 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -15,7 +15,7 @@ def countDuplicates(field e11, field e12, field e21, field e22) -> field { duplicates = duplicates + e11 == e21 ? 1 : 0; duplicates = duplicates + e11 == e22 ? 1 : 0; duplicates = duplicates + e12 == e21 ? 1 : 0; - duplicates = duplicates + e12 == e21 ? 1 : 0; + duplicates = duplicates + e12 == e22 ? 1 : 0; duplicates = duplicates + e21 == e22 ? 1 : 0; return duplicates; } diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index eb7ba16e4..92a972482 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -24,8 +24,8 @@ pub fn subcommand() -> App<'static, 'static> { .arg( Arg::with_name("radix-path") .short("r") - .long("radix-dir") - .help("Path of the directory containing parameters for various 2^m circuit depths (phase1radix2m{0..=m})") + .long("radix-path") + .help("Path of the phase1radix2m{n} file") .value_name("PATH") .takes_value(true) .required(true), diff --git a/zokrates_cli/src/ops/mpc/mod.rs b/zokrates_cli/src/ops/mpc/mod.rs index 000d82023..b01a1ddc9 100644 --- a/zokrates_cli/src/ops/mpc/mod.rs +++ b/zokrates_cli/src/ops/mpc/mod.rs @@ -1,4 +1,4 @@ -use clap::{App, ArgMatches, SubCommand}; +use clap::{App, AppSettings, ArgMatches, SubCommand}; pub mod beacon; pub mod contribute; @@ -9,6 +9,7 @@ pub mod verify; pub fn subcommand() -> App<'static, 'static> { SubCommand::with_name("mpc") .about("Multi-party computation (MPC) protocol") + .setting(AppSettings::SubcommandRequiredElseHelp) .subcommands(vec![ init::subcommand().display_order(1), contribute::subcommand().display_order(2), diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index aa1964c52..2b2530047 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index a4f7bb0c3..4e006eed3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1170,7 +1170,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let id = arg.id.value.id; let info = IdentifierInfo { - id: decl_v.id.id.clone(), + id: decl_v.id.id.id.clone(), ty, is_mutable, }; diff --git a/zokrates_core_test/tests/tests/call_ssa.json b/zokrates_core_test/tests/tests/call_ssa.json new file mode 100644 index 000000000..ae021f62d --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.json @@ -0,0 +1,15 @@ +{ + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "4" + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/call_ssa.zok b/zokrates_core_test/tests/tests/call_ssa.zok new file mode 100644 index 000000000..dad61d190 --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.zok @@ -0,0 +1,11 @@ +// main should be x -> x + 4 + +def foo(field mut a) -> field { + a = a + 1; + return a + 1; +} + +def main(field mut a) -> field { + a = foo(a + 1); + return a + 1; +} \ No newline at end of file diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 77aaff662..270436685 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] default = ["bellman", "ark"] -bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman", "zokrates_ast/bellman"] -ark = ["ark-bls12-377", "zokrates_embed/ark", "zokrates_ast/ark"] +bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman", "zokrates_ast/bellman", "zokrates_analysis/bellman"] +ark = ["ark-bls12-377", "zokrates_embed/ark", "zokrates_ast/ark", "zokrates_analysis/ark"] [dependencies] zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 0c86329b9..02374289b 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/index.js b/zokrates_js/index.js index a6d689f8e..55a7dbda6 100644 --- a/zokrates_js/index.js +++ b/zokrates_js/index.js @@ -68,13 +68,25 @@ const initialize = async () => { return result; }, setup: (program, entropy, options) => { - return wasmExports.setup(program, entropy, options); + const ptr = wasmExports.setup(program, entropy, options); + const result = { + vk: ptr.vk(), + pk: ptr.pk(), + }; + ptr.free(); + return result; }, universalSetup: (curve, size, entropy) => { return wasmExports.universal_setup(curve, size, entropy); }, setupWithSrs: (srs, program, options) => { - return wasmExports.setup_with_srs(srs, program, options); + const ptr = wasmExports.setup_with_srs(srs, program, options); + const result = { + vk: ptr.vk(), + pk: ptr.pk(), + }; + ptr.free(); + return result; }, generateProof: (program, witness, provingKey, entropy, options) => { return wasmExports.generate_proof( diff --git a/zokrates_js/package-lock.json b/zokrates_js/package-lock.json index 00ecd023a..0821df1a2 100644 --- a/zokrates_js/package-lock.json +++ b/zokrates_js/package-lock.json @@ -1,12 +1,12 @@ { "name": "zokrates-js", - "version": "1.1.4", + "version": "1.1.5", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "zokrates-js", - "version": "1.1.4", + "version": "1.1.5", "license": "GPLv3", "dependencies": { "pako": "^2.1.0" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 7617f20f3..02b6e919e 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.5", + "version": "1.1.6", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates", diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 69513ea00..2b4dccebc 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -53,6 +53,7 @@ impl CompilationResult { arr.copy_from(&self.program); arr } + pub fn abi(&self) -> JsValue { JsValue::from_serde(&self.abi).unwrap() } @@ -88,9 +89,11 @@ impl ComputationResult { pub fn witness(&self) -> JsValue { JsValue::from_str(&self.witness) } + pub fn output(&self) -> JsValue { JsValue::from_str(&self.output) } + pub fn snarkjs_witness(&self) -> Option { self.snarkjs_witness.as_ref().map(|w| { let arr = js_sys::Uint8Array::new_with_length(w.len() as u32); @@ -100,6 +103,25 @@ impl ComputationResult { } } +#[wasm_bindgen] +pub struct Keypair { + vk: JsValue, + pk: Vec, +} + +#[wasm_bindgen] +impl Keypair { + pub fn vk(&self) -> JsValue { + self.vk.to_owned() + } + + pub fn pk(&self) -> js_sys::Uint8Array { + let arr = js_sys::Uint8Array::new_with_length(self.pk.len() as u32); + arr.copy_from(&self.pk); + arr + } +} + pub struct JsResolver<'a> { callback: &'a js_sys::Function, } @@ -204,6 +226,7 @@ impl<'a> Write for LogWriter<'a> { fn write(&mut self, buf: &[u8]) -> std::io::Result { self.buf.write(buf) } + fn flush(&mut self) -> std::io::Result<()> { self.callback .call1( @@ -352,10 +375,13 @@ mod internal { >( program: ir::Prog, rng: &mut R, - ) -> JsValue { + ) -> Keypair { let keypair = B::setup(program, rng); let tagged_keypair = TaggedKeypair::::new(keypair); - JsValue::from_serde(&tagged_keypair).unwrap() + Keypair { + vk: JsValue::from_serde(&tagged_keypair.vk).unwrap(), + pk: tagged_keypair.pk, + } } pub fn setup_universal< @@ -367,9 +393,13 @@ mod internal { >( srs: &[u8], program: ir::ProgIterator<'a, T, I>, - ) -> Result { + ) -> Result { let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?; - Ok(JsValue::from_serde(&TaggedKeypair::::new(keypair)).unwrap()) + let tagged_keypair = TaggedKeypair::::new(keypair); + Ok(Keypair { + vk: JsValue::from_serde(&tagged_keypair.vk).unwrap(), + pk: tagged_keypair.pk, + }) } pub fn universal_setup_of_size< @@ -528,7 +558,7 @@ pub fn export_solidity_verifier(vk: JsValue) -> Result { } #[wasm_bindgen] -pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result { +pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result { let options: serde_json::Value = options.into_serde().unwrap(); let backend = BackendParameter::try_from( @@ -597,7 +627,7 @@ pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result Result { +pub fn setup_with_srs(srs: &[u8], program: &[u8], options: JsValue) -> Result { let options: serde_json::Value = options.into_serde().unwrap(); let scheme = SchemeParameter::try_from(