diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d1ea0251b..590c9814a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,24 +1,27 @@ name: Rust -on: - push: - branches: [ "master" ] - pull_request: - branches: [ "master" ] - -env: - CARGO_TERM_COLOR: always +on: [push, pull_request] jobs: - build: + check_n_test: + name: cargo check & test + uses: noir-lang/.github/.github/workflows/rust-test.yml@main - runs-on: ubuntu-latest + clippy: + name: cargo clippy + uses: noir-lang/.github/.github/workflows/rust-clippy.yml@main + format: + name: cargo fmt + uses: noir-lang/.github/.github/workflows/rust-format.yml@main + + spellcheck: + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Build - run: cargo build --verbose - - name: Clippy - run: cargo clippy --verbose - - name: Run tests - run: cargo test --verbose + - uses: actions/checkout@v3 + - uses: streetsidesoftware/cspell-action@v2 + with: + files: | + **/*.{md,rs} + incremental_files_only: true # Run this action on files which have changed in PR + strict: false # Do not fail, if a spelling mistake is found (This can be annoying for contributors) diff --git a/CHANGELOG.md b/CHANGELOG.md index e64cad60c..5e40fadda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,14 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [0.4.1] - 2023-02-08 ### Added ### Fixed +- Removed duplicated logic in match branch + +### Changed + +### Removed + +## [0.4.0] - 2023-02-08 + +### Added + +- Add log directive +- Expose `acir_field` through `acir` crate +- Add permutation directive +- Add preprocess methods to ACVM interface + +### Fixed + ### Changed +- Changed spellings of many functions to be correct using spellchecker + ### Removed ## [0.3.1] - 2023-01-18 @@ -42,7 +61,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - XOR, Range and AND gates are no longer special case. They are now another opcode in the GadgetCall - Move fallback module to `stdlib` -- optimiser code and any other passes will live in acvm. acir is solely for defining the IR now. +- Optimizer code and any other passes will live in acvm. acir is solely for defining the IR now. - ACIR passes now live under the compiler parent module - Moved opcode module in acir crate to circuit/opcode - Rename GadgetCall to BlackBoxFuncCall diff --git a/README.md b/README.md index 7b38a27a0..49154bdc3 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # ACIR - Abstract Circuit Intermediate Representation -ACIR is an NP complete language that generalises R1CS and arithmetic circuits while not losing proving system specific optimisations through the use of black box functions. +ACIR is an NP complete language that generalizes R1CS and arithmetic circuits while not losing proving system specific optimizations through the use of black box functions. # ACVM - Abstract Circuit Virtual Machine This can be seen as the ACIR compiler. It will take an ACIR instance and convert it to the format required -by a particular proving system to create a proof. \ No newline at end of file +by a particular proving system to create a proof. diff --git a/acir/Cargo.toml b/acir/Cargo.toml index 33b042ba6..306d2f49a 100644 --- a/acir/Cargo.toml +++ b/acir/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "acir" -version = "0.3.1" +version = "0.4.1" authors = ["Kevaundray Wedderburn "] edition = "2021" license = "MIT" @@ -9,7 +9,7 @@ description = "ACIR is the IR that the VM processes, it is analogous to LLVM IR" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -acir_field = { version = "0.3.1", path = "../acir_field" } +acir_field = { version = "0.4.1", path = "../acir_field" } serde = { version = "1.0.136", features = ["derive"] } rmp-serde = "1.1.0" flate2 = "1.0.24" @@ -18,3 +18,7 @@ flate2 = "1.0.24" serde_json = "1.0" strum = "0.24" strum_macros = "0.24" + +[features] +bn254 = ["acir_field/bn254"] +bls12_381 = ["acir_field/bls12_381"] diff --git a/acir/src/circuit/blackbox_functions.rs b/acir/src/circuit/black_box_functions.rs similarity index 100% rename from acir/src/circuit/blackbox_functions.rs rename to acir/src/circuit/black_box_functions.rs diff --git a/acir/src/circuit/directives.rs b/acir/src/circuit/directives.rs index c3d8e2cd0..f3f5e6aca 100644 --- a/acir/src/circuit/directives.rs +++ b/acir/src/circuit/directives.rs @@ -2,7 +2,7 @@ use std::io::{Read, Write}; use crate::{ native_types::{Expression, Witness}, - serialisation::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}, + serialization::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}, }; use serde::{Deserialize, Serialize}; @@ -48,6 +48,16 @@ pub enum Directive { b: Vec, radix: u32, }, + + // Sort directive, using a sorting network + // This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted according to sort_by + PermutationSort { + inputs: Vec>, // Array of tuples to sort + tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc.. + bits: Vec, // control bits of the network which permutes the inputs into its sorted version + sort_by: Vec, // specify primary index to sort by, then the secondary,... For instance, if tuple is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai. + }, + Log(LogInfo), } impl Directive { @@ -58,6 +68,8 @@ impl Directive { Directive::Truncate { .. } => "truncate", Directive::OddRange { .. } => "odd_range", Directive::ToRadix { .. } => "to_radix", + Directive::PermutationSort { .. } => "permutation_sort", + Directive::Log { .. } => "log", } } fn to_u16(&self) -> u16 { @@ -67,6 +79,8 @@ impl Directive { Directive::Truncate { .. } => 2, Directive::OddRange { .. } => 3, Directive::ToRadix { .. } => 4, + Directive::Log { .. } => 5, + Directive::PermutationSort { .. } => 6, } } @@ -116,6 +130,39 @@ impl Directive { } write_u32(&mut writer, *radix)?; } + Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + } => { + write_u32(&mut writer, *tuple)?; + write_u32(&mut writer, a.len() as u32)?; + for e in a { + for i in 0..*tuple { + e[i as usize].write(&mut writer)?; + } + } + write_u32(&mut writer, bits.len() as u32)?; + for b in bits { + write_u32(&mut writer, b.witness_index())?; + } + write_u32(&mut writer, sort_by.len() as u32)?; + for i in sort_by { + write_u32(&mut writer, *i)?; + } + } + Directive::Log(info) => match info { + LogInfo::FinalizedOutput(output_string) => { + write_bytes(&mut writer, output_string.as_bytes())?; + } + LogInfo::WitnessOutput(witnesses) => { + write_u32(&mut writer, witnesses.len() as u32)?; + for w in witnesses { + write_u32(&mut writer, w.witness_index())?; + } + } + }, }; Ok(()) @@ -178,14 +225,53 @@ impl Directive { Ok(Directive::ToRadix { a, b, radix }) } + 6 => { + let tuple = read_u32(&mut reader)?; + let a_len = read_u32(&mut reader)?; + let mut a = Vec::with_capacity(a_len as usize); + for _ in 0..a_len { + let mut element = Vec::new(); + for _ in 0..tuple { + element.push(Expression::read(&mut reader)?); + } + a.push(element); + } + + let bits_len = read_u32(&mut reader)?; + let mut bits = Vec::with_capacity(bits_len as usize); + for _ in 0..bits_len { + bits.push(Witness(read_u32(&mut reader)?)); + } + let sort_by_len = read_u32(&mut reader)?; + let mut sort_by = Vec::with_capacity(sort_by_len as usize); + for _ in 0..sort_by_len { + sort_by.push(read_u32(&mut reader)?); + } + Ok(Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + }) + } _ => Err(std::io::ErrorKind::InvalidData.into()), } } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +// If values are compile time and/or known during +// evaluation, we can form an output string during ACIR generation. +// Otherwise, we must store witnesses whose values will +// be fetched during the PWG stage. +pub enum LogInfo { + FinalizedOutput(String), + WitnessOutput(Vec), +} + #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { fn read_write(directive: Directive) -> (Directive, Directive) { let mut bytes = Vec::new(); directive.write(&mut bytes).unwrap(); diff --git a/acir/src/circuit/mod.rs b/acir/src/circuit/mod.rs index fed4e0808..a043aee20 100644 --- a/acir/src/circuit/mod.rs +++ b/acir/src/circuit/mod.rs @@ -1,10 +1,10 @@ -pub mod blackbox_functions; +pub mod black_box_functions; pub mod directives; pub mod opcodes; pub use opcodes::Opcode; use crate::native_types::Witness; -use crate::serialisation::{read_u32, write_u32}; +use crate::serialization::{read_u32, write_u32}; use rmp_serde; use serde::{Deserialize, Serialize}; @@ -29,7 +29,7 @@ impl Circuit { } #[deprecated( - note = "we want to use a serialisation strategy that is easy to implement in many languages (without ffi). use `read` instead" + note = "we want to use a serialization strategy that is easy to implement in many languages (without ffi). use `read` instead" )] pub fn from_bytes(bytes: &[u8]) -> Circuit { let mut deflater = DeflateDecoder::new(bytes); @@ -39,7 +39,7 @@ impl Circuit { } #[deprecated( - note = "we want to use a serialisation strategy that is easy to implement in many languages (without ffi).use `write` instead" + note = "we want to use a serialization strategy that is easy to implement in many languages (without ffi).use `write` instead" )] pub fn to_bytes(&self) -> Vec { let buf = rmp_serde::to_vec(&self).unwrap(); @@ -71,7 +71,7 @@ impl Circuit { // TODO (Note): we could use semver versioning from the Cargo.toml // here and then reject anything that has a major bump // - // We may also not want to do that if we do not want to couple serialisation + // We may also not want to do that if we do not want to couple serialization // with other breaking changes if version_number != VERSION_NUMBER { return Err(std::io::ErrorKind::InvalidData.into()); @@ -180,7 +180,7 @@ mod test { } #[test] - fn serialisation_roundtrip() { + fn serialization_roundtrip() { let circuit = Circuit { current_witness_index: 5, opcodes: vec![and_opcode(), range_opcode()], diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index 37017adde..8d6794387 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -1,8 +1,8 @@ use std::io::{Read, Write}; -use super::directives::Directive; +use super::directives::{Directive, LogInfo}; use crate::native_types::{Expression, Witness}; -use crate::serialisation::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}; +use crate::serialization::{read_n, read_u16, read_u32, write_bytes, write_u16, write_u32}; use crate::BlackBoxFunc; use serde::{Deserialize, Serialize}; @@ -25,7 +25,7 @@ impl Opcode { } // We have three types of opcodes allowed in the IR // Expression, BlackBoxFuncCall and Directives - // When we serialise these opcodes, we use the index + // When we serialize these opcodes, we use the index // to uniquely identify which category of opcode we are dealing with. pub(crate) fn to_index(&self) -> u8 { match self { @@ -164,6 +164,33 @@ impl std::fmt::Display for Opcode { b.last().unwrap().witness_index(), ) } + Opcode::Directive(Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + }) => { + write!(f, "DIR::PERMUTATIONSORT ")?; + write!( + f, + "(permutation size: {} {}-tuples, sort_by: {:#?}, bits: [_{}..._{}]))", + a.len(), + tuple, + sort_by, + // (Note): the bits do not have contiguous index but there are too many for display + bits.first().unwrap().witness_index(), + bits.last().unwrap().witness_index(), + ) + } + Opcode::Directive(Directive::Log(info)) => match info { + LogInfo::FinalizedOutput(output_string) => write!(f, "Log: {output_string}"), + LogInfo::WitnessOutput(witnesses) => write!( + f, + "Log: _{}..._{}", + witnesses.first().unwrap().witness_index(), + witnesses.last().unwrap().witness_index() + ), + }, } } } @@ -326,7 +353,7 @@ impl std::fmt::Debug for BlackBoxFuncCall { } #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { fn read_write(opcode: Opcode) -> (Opcode, Opcode) { let mut bytes = Vec::new(); opcode.write(&mut bytes).unwrap(); @@ -336,7 +363,7 @@ fn serialisation_roundtrip() { let opcode_arith = Opcode::Arithmetic(Expression::default()); - let opcode_blackbox_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall { + let opcode_black_box_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall { name: BlackBoxFunc::AES, inputs: vec![ FunctionInput { @@ -356,7 +383,7 @@ fn serialisation_roundtrip() { result: Witness(56789u32), }); - let opcodes = vec![opcode_arith, opcode_blackbox_func, opcode_directive]; + let opcodes = vec![opcode_arith, opcode_black_box_func, opcode_directive]; for opcode in opcodes { let (op, got_op) = read_write(opcode); diff --git a/acir/src/lib.rs b/acir/src/lib.rs index ff1f4a951..48a3f2df7 100644 --- a/acir/src/lib.rs +++ b/acir/src/lib.rs @@ -2,7 +2,8 @@ pub mod circuit; pub mod native_types; -mod serialisation; +mod serialization; +pub use acir_field; pub use acir_field::FieldElement; -pub use circuit::blackbox_functions::BlackBoxFunc; +pub use circuit::black_box_functions::BlackBoxFunc; diff --git a/acir/src/native_types/arithmetic.rs b/acir/src/native_types/arithmetic.rs index 56621919a..342a70819 100644 --- a/acir/src/native_types/arithmetic.rs +++ b/acir/src/native_types/arithmetic.rs @@ -1,5 +1,5 @@ use crate::native_types::{Linear, Witness}; -use crate::serialisation::{read_field_element, read_u32, write_bytes, write_u32}; +use crate::serialization::{read_field_element, read_u32, write_bytes, write_u32}; use acir_field::FieldElement; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; @@ -17,7 +17,7 @@ use super::witness::UnknownWitness; // XXX: If we allow the degree of the quotient polynomial to be arbitrary, then we will need a vector of wire values #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Expression { - // To avoid having to create intermediate variables pre-optimisation + // To avoid having to create intermediate variables pre-optimization // We collect all of the multiplication terms in the arithmetic gate // A multiplication term if of the form q_M * wL * wR // Hence this vector represents the following sum: q_M1 * wL1 * wR1 + q_M2 * wL2 * wR2 + .. + @@ -448,7 +448,7 @@ impl Expression { // A polynomial whose mul terms are non zero which do not match up with two terms in the fan-in cannot fit into one gate // An example of this is: Axy + Bx + Cy + ... // Notice how the bivariate monomial xy has two univariate monomials with their respective coefficients - // XXX: note that if x or y is zero, then we could apply a further optimisation, but this would be done in another algorithm. + // XXX: note that if x or y is zero, then we could apply a further optimization, but this would be done in another algorithm. // It would be the same as when we have zero coefficients - Can only work if wire is constrained to be zero publicly let mul_term = &self.mul_terms[0]; @@ -478,7 +478,7 @@ impl Expression { } #[test] -fn serialisation_roundtrip() { +fn serialization_roundtrip() { // Empty expression // let expr = Expression::default(); diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index 0795d4e28..b81eee65c 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -56,7 +56,7 @@ impl Witness { // We use this, so that they are pushed to the beginning of the array // // When they are pushed to the beginning of the array, they are less likely to be used in an intermediate gate -// by the optimiser, which would mean two unknowns in an equation. +// by the optimizer, which would mean two unknowns in an equation. // See Issue #20 // TODO: can we find a better solution to this? pub struct UnknownWitness(pub u32); diff --git a/acir/src/serialisation.rs b/acir/src/serialization.rs similarity index 95% rename from acir/src/serialisation.rs rename to acir/src/serialization.rs index 838e01667..1fe719936 100644 --- a/acir/src/serialisation.rs +++ b/acir/src/serialization.rs @@ -43,7 +43,7 @@ pub fn read_field_element( let bytes = read_n::(&mut r)?; - // TODO: We should not reduce here, we want the serialisation to be + // TODO: We should not reduce here, we want the serialization to be // TODO canonical let field_element = FieldElement::from_be_bytes_reduce(&bytes); diff --git a/acir_field/Cargo.toml b/acir_field/Cargo.toml index dc7872451..f789ec6e8 100644 --- a/acir_field/Cargo.toml +++ b/acir_field/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "acir_field" -version = "0.3.1" +version = "0.4.1" authors = ["Kevaundray Wedderburn "] edition = "2021" license = "MIT" @@ -10,13 +10,14 @@ description = "The field implementation being used by ACIR." [dependencies] hex = "0.4.2" -ark-bn254 = { version = "^0.3.0", optional = true, default-features = false, features = [ +ark-bn254 = { version = "^0.4.0", optional = true, default-features = false, features = [ "curve", ] } -ark-bls12-381 = { version = "^0.3.0", optional = true, default-features = false, features = [ +ark-bls12-381 = { version = "^0.4.0", optional = true, default-features = false, features = [ "curve", ] } -ark-ff = { version = "^0.3.0", optional = true, default-features = false } +ark-ff = { version = "^0.4.0", optional = true, default-features = false } +ark-serialize = { version = "^0.4.0", default-features = false } blake2 = "0.9.1" cfg-if = "1.0.0" @@ -25,9 +26,6 @@ serde = { version = "1.0.136", features = ["derive"] } num-bigint = "0.4" num-traits = "0.2.8" -[dev-dependencies] -ark-bn254 = { version = "^0.3.0", features = ["curve"] } - [features] default = ["bn254"] bn254 = ["ark-bn254", "ark-ff"] diff --git a/acir_field/src/generic_ark.rs b/acir_field/src/generic_ark.rs index fb3722ea7..2e78e6c30 100644 --- a/acir_field/src/generic_ark.rs +++ b/acir_field/src/generic_ark.rs @@ -1,5 +1,3 @@ -use ark_ff::to_bytes; -use ark_ff::FpParameters; use ark_ff::PrimeField; use ark_ff::Zero; use num_bigint::BigUint; @@ -43,7 +41,13 @@ impl std::fmt::Display for FieldElement { break; } } - return write!(f, "2{}", superscript(bit_index)); + return match bit_index { + 0 => write!(f, "1"), + 1 => write!(f, "2"), + 2 => write!(f, "4"), + 3 => write!(f, "8"), + _ => write!(f, "2{}", superscript(bit_index)), + }; } // Check if number is a multiple of a power of 2. @@ -51,7 +55,7 @@ impl std::fmt::Display for FieldElement { // we usually have numbers in the form 2^t * q + r // We focus on 2^64, 2^32, 2^16, 2^8, 2^4 because // they are common. We could extend this to a more - // general factorisation strategy, but we pay in terms of CPU time + // general factorization strategy, but we pay in terms of CPU time let mul_sign = "×"; for power in [64, 32, 16, 8, 4] { let power_of_two = BigUint::from(2_u128).pow(power); @@ -157,7 +161,7 @@ impl FieldElement { } pub fn pow(&self, exponent: &Self) -> Self { - FieldElement(self.0.pow(exponent.0.into_repr())) + FieldElement(self.0.pow(exponent.0.into_bigint())) } /// Maximum number of bits needed to represent a field element @@ -166,7 +170,7 @@ impl FieldElement { /// But the representation uses 256 bits, so the top two bits are always zero /// This method would return 254 pub const fn max_num_bits() -> u32 { - F::Params::MODULUS_BITS + F::MODULUS_BIT_SIZE } /// Maximum numbers of bytes needed to represent a field element @@ -183,7 +187,7 @@ impl FieldElement { } pub fn modulus() -> BigUint { - F::Params::MODULUS.into() + F::MODULUS.into() } /// Returns None, if the string is not a canonical /// representation of a field element; less than the order @@ -226,7 +230,7 @@ impl FieldElement { } /// Computes the inverse or returns zero if the inverse does not exist - /// Before using this FieldElement, please ensure that this behaviour is necessary + /// Before using this FieldElement, please ensure that this behavior is necessary pub fn inverse(&self) -> FieldElement { let inv = self.0.inverse().unwrap_or_else(F::zero); FieldElement(inv) @@ -243,7 +247,8 @@ impl FieldElement { } pub fn to_hex(self) -> String { - let mut bytes = to_bytes!(self.0).unwrap(); + let mut bytes = Vec::new(); + self.0.serialize_uncompressed(&mut bytes).unwrap(); bytes.reverse(); hex::encode(bytes) } @@ -257,7 +262,8 @@ impl FieldElement { // to_be_bytes! uses little endian which is why we reverse the output // TODO: Add a little endian equivalent, so the caller can use whichever one // TODO they desire - let mut bytes = to_bytes!(self.0).unwrap(); + let mut bytes = Vec::new(); + self.0.serialize_uncompressed(&mut bytes).unwrap(); bytes.reverse(); bytes } @@ -297,7 +303,7 @@ impl FieldElement { let num_elements = num_bytes / 8; let mut bytes = self.to_be_bytes(); - bytes.reverse(); // put it in big endian format. XXX(next refactor): we should be explicit about endianess. + bytes.reverse(); // put it in big endian format. XXX(next refactor): we should be explicit about endianness. bytes[0..num_elements].to_vec() } @@ -405,6 +411,28 @@ mod test { assert_eq!(res.to_be_bytes(), x.to_be_bytes()); } } + + #[test] + fn serialize_fixed_test_vectors() { + // Serialized field elements from of 0, -1, -2, -3 + let hex_strings = vec![ + "0000000000000000000000000000000000000000000000000000000000000000", + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000", + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593efffffff", + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593effffffe", + ]; + + for (i, string) in hex_strings.into_iter().enumerate() { + let minus_i_field_element = + -crate::generic_ark::FieldElement::::from(i as i128); + assert_eq!(minus_i_field_element.to_hex(), string) + } + } + #[test] + fn max_num_bits_smoke() { + let max_num_bits_bn254 = crate::generic_ark::FieldElement::::max_num_bits(); + assert_eq!(max_num_bits_bn254, 254) + } } fn mask_vector_le(bytes: &mut [u8], num_bits: usize) { diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index cbd53d0fd..0d7749780 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "acvm" -version = "0.3.1" +version = "0.4.1" authors = ["Kevaundray Wedderburn "] edition = "2021" license = "MIT" @@ -10,9 +10,8 @@ description = "The virtual machine that processes ACIR given a backend/proof sys [dependencies] num-bigint = "0.4" num-traits = "0.2" -acir = { version = "0.3.1", path = "../acir" } -acir_field = { version = "0.3.1", path = "../acir_field", default-features = false } -stdlib = { package = "acvm_stdlib", version = "0.3.0", path = "../stdlib" } +acir = { version = "0.4.1", path = "../acir" } +stdlib = { package = "acvm_stdlib", version = "0.4.1", path = "../stdlib" } sha2 = "0.9.3" blake2 = "0.9.1" @@ -28,8 +27,9 @@ indexmap = "1.7.0" thiserror = "1.0.21" [features] -bn254 = ["acir_field/bn254"] -bls12_381 = ["acir_field/bls12_381"] +bn254 = ["acir/bn254"] +bls12_381 = ["acir/bls12_381"] [dev-dependencies] tempfile = "3.2.0" +rand = "0.8.5" diff --git a/acvm/src/compiler.rs b/acvm/src/compiler.rs index d4b5b0ce8..bc27e875a 100644 --- a/acvm/src/compiler.rs +++ b/acvm/src/compiler.rs @@ -1,5 +1,5 @@ // The various passes that we can use over ACIR -pub mod optimisers; +pub mod optimizers; pub mod transformers; use crate::Language; @@ -9,7 +9,7 @@ use acir::{ BlackBoxFunc, }; use indexmap::IndexMap; -use optimisers::GeneralOptimiser; +use optimizers::GeneralOptimizer; use thiserror::Error; use transformers::{CSatTransformer, FallbackTransformer, IsBlackBoxSupported, R1CSTransformer}; @@ -22,21 +22,21 @@ pub enum CompileError { pub fn compile( acir: Circuit, np_language: Language, - is_blackbox_supported: IsBlackBoxSupported, + is_black_box_supported: IsBlackBoxSupported, ) -> Result { - // Instantiate the optimiser. - // Currently the optimiser and reducer are one in the same + // Instantiate the optimizer. + // Currently the optimizer and reducer are one in the same // for CSAT // Fallback transformer pass - let acir = FallbackTransformer::transform(acir, is_blackbox_supported)?; + let acir = FallbackTransformer::transform(acir, is_black_box_supported)?; - // General optimiser pass + // General optimizer pass let mut opcodes: Vec = Vec::new(); for opcode in acir.opcodes { match opcode { Opcode::Arithmetic(arith_expr) => { - opcodes.push(Opcode::Arithmetic(GeneralOptimiser::optimise(arith_expr))) + opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr))) } other_gate => opcodes.push(other_gate), }; @@ -53,9 +53,9 @@ pub fn compile( // TODO: the code below is only for CSAT transformer // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs - // TODO or at the very least, we could put all of it inside of CSATOptimiser pass + // TODO or at the very least, we could put all of it inside of CSatOptimizer pass - // Optimise the arithmetic gates by reducing them into the correct width and + // Optimize the arithmetic gates by reducing them into the correct width and // creating intermediate variables when necessary let mut transformed_gates = Vec::new(); @@ -93,6 +93,6 @@ pub fn compile( Ok(Circuit { current_witness_index, opcodes: transformed_gates, - public_inputs: acir.public_inputs, // The optimiser does not add public inputs + public_inputs: acir.public_inputs, // The optimizer does not add public inputs }) } diff --git a/acvm/src/compiler/optimisers/mod.rs b/acvm/src/compiler/optimisers/mod.rs deleted file mode 100644 index 5624e1f06..000000000 --- a/acvm/src/compiler/optimisers/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod general; - -pub use general::GeneralOpt as GeneralOptimiser; diff --git a/acvm/src/compiler/optimisers/general.rs b/acvm/src/compiler/optimizers/general.rs similarity index 87% rename from acvm/src/compiler/optimisers/general.rs rename to acvm/src/compiler/optimizers/general.rs index 65e95e13b..6f81c896e 100644 --- a/acvm/src/compiler/optimisers/general.rs +++ b/acvm/src/compiler/optimizers/general.rs @@ -6,8 +6,8 @@ use indexmap::IndexMap; pub struct GeneralOpt; impl GeneralOpt { - pub fn optimise(gate: Expression) -> Expression { - // XXX: Perhaps this optimisation can be done on the fly + pub fn optimize(gate: Expression) -> Expression { + // XXX: Perhaps this optimization can be done on the fly let gate = remove_zero_coefficients(gate); simplify_mul_terms(gate) } @@ -27,7 +27,7 @@ pub fn remove_zero_coefficients(mut gate: Expression) -> Expression { pub fn simplify_mul_terms(mut gate: Expression) -> Expression { let mut hash_map: IndexMap<(Witness, Witness), FieldElement> = IndexMap::new(); - // Canonicalise the ordering of the multiplication, lets just order by variable name + // Canonicalize the ordering of the multiplication, lets just order by variable name for (scale, w_l, w_r) in gate.mul_terms.clone().into_iter() { let mut pair = vec![w_l, w_r]; // Sort using rust sort algorithm diff --git a/acvm/src/compiler/optimizers/mod.rs b/acvm/src/compiler/optimizers/mod.rs new file mode 100644 index 000000000..2b7b95f28 --- /dev/null +++ b/acvm/src/compiler/optimizers/mod.rs @@ -0,0 +1,3 @@ +mod general; + +pub use general::GeneralOpt as GeneralOptimizer; diff --git a/acvm/src/compiler/optimisers/range.rs b/acvm/src/compiler/optimizers/range.rs similarity index 54% rename from acvm/src/compiler/optimisers/range.rs rename to acvm/src/compiler/optimizers/range.rs index 6d733cdb6..9bfa86d08 100644 --- a/acvm/src/compiler/optimisers/range.rs +++ b/acvm/src/compiler/optimizers/range.rs @@ -1,19 +1,19 @@ // XXX: We could alleviate a runtime check from noir // By casting directly // Example: -// priv z1 = x as u32 -// priv z2 = x as u16 +// priv z1 = x as u32 +// priv z2 = x as u16 // // The IR would see both casts and replace it with -// -// +// +// // priv z1 = x as u16; // priv z2 = x as u16; // -// -// Then maybe another optimisation could be done so that it transforms into +// +// Then maybe another optimization could be done so that it transforms into // // priv z1 = x as u16 // priv z2 = z1 -// This is what I would call a general optimisation, so it could live inside of the IR module -// A more specific optimisation would be to have z2 = z1 not use a gate (copy_from_to), this is more specific to plonk-aztec and would not live in this module \ No newline at end of file +// This is what I would call a general optimization, so it could live inside of the IR module +// A more specific optimization would be to have z2 = z1 not use a gate (copy_from_to), this is more specific to plonk-aztec and would not live in this module diff --git a/acvm/src/compiler/transformers/csat.rs b/acvm/src/compiler/transformers/csat.rs index 1e27d789d..72bb96925 100644 --- a/acvm/src/compiler/transformers/csat.rs +++ b/acvm/src/compiler/transformers/csat.rs @@ -8,7 +8,7 @@ use indexmap::IndexMap; // Is this more of a Reducer than an optimiser? // Should we give it all of the gates? -// Have a single optimiser that you instantiate with a width, then pass many gates through +// Have a single transformer that you instantiate with a width, then pass many gates through pub struct CSatTransformer { width: usize, } @@ -21,9 +21,9 @@ impl CSatTransformer { CSatTransformer { width } } - // Still missing dead witness optimisation. + // Still missing dead witness optimization. // To do this, we will need the whole set of arithmetic gates - // I think it can also be done before the local optimisation seen here, as dead variables will come from the user + // I think it can also be done before the local optimization seen here, as dead variables will come from the user pub fn transform( &self, gate: Expression, @@ -31,23 +31,23 @@ impl CSatTransformer { num_witness: u32, ) -> Expression { // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full gate - let gate = self.full_gate_scan_optimisation(gate, intermediate_variables, num_witness); - // The last optimisation to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms + let gate = self.full_gate_scan_optimization(gate, intermediate_variables, num_witness); + // The last optimization to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms // If a gate has more than one mul term. We may need an intermediate variable for each one. Since not every variable will need to link to // the mul term, we could possibly do it that way. - // We wil call this a partial gate scan optimisation which will result in the gates being able to fit into the correct width + // We wil call this a partial gate scan optimization which will result in the gates being able to fit into the correct width let mut gate = - self.partial_gate_scan_optimisation(gate, intermediate_variables, num_witness); + self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness); gate.sort(); gate } - // This optimisation will search for combinations of terms which can be represented in a single arithmetic gate + // This optimization will search for combinations of terms which can be represented in a single arithmetic gate // Case 1 : qM * wL * wR + qL * wL + qR * wR + qO * wO + qC - // This polynomial does not require any further optimisations, it can be safely represented in one gate + // This polynomial does not require any further optimizations, it can be safely represented in one gate // ie a polynomial with 1 mul(bi-variate) term and 3 (univariate) terms where 2 of those terms match the bivariate term // wL and wR, we can represent it in one gate - // GENERALISED for WIDTH: instead of the number 3, we use `WIDTH` + // GENERALIZED for WIDTH: instead of the number 3, we use `WIDTH` // // // Case 2: qM * wL * wR + qL * wL + qR * wR + qO * wO + qC + qM2 * wL2 * wR2 + qL * wL2 + qR * wR2 + qO * wO2 + qC2 @@ -65,7 +65,7 @@ impl CSatTransformer { // The polynomial now looks like so t + t2 // We can no longer extract another full gate, hence the algorithm terminates. Creating two intermediate variables t and t2. // This stage of preprocessing does not guarantee that all polynomials can fit into a gate. It only guarantees that all full gates have been extracted from each polynomial - fn full_gate_scan_optimisation( + fn full_gate_scan_optimization( &self, mut gate: Expression, intermediate_variables: &mut IndexMap, @@ -74,11 +74,11 @@ impl CSatTransformer { // We pass around this intermediate variable IndexMap, so that we do not create intermediate variables that we have created before // One instance where this might happen is t1 = wL * wR and t2 = wR * wL - // First check that this is not a simple gate which does not need optimisation + // First check that this is not a simple gate which does not need optimization // - // If the gate only has one mul term, then this algorithm cannot optimise it any further + // If the gate only has one mul term, then this algorithm cannot optimize it any further // Either it can be represented in a single arithmetic equation or it's fan-in is too large and we need intermediate variables for those - // large-fan-in optimisation is not this algorithms purpose. + // large-fan-in optimization is not this algorithms purpose. // If the gate has 0 mul terms, then it is an add gate and similarly it can either fit into a single arithmetic gate or it has a large fan-in if gate.mul_terms.len() <= 1 { return gate; @@ -98,7 +98,7 @@ impl CSatTransformer { // Check if this pair is present in the simplified fan-in // We are assuming that the fan-in/fan-out has been simplified. - // Note this function is not public, and can only be called within the optimise method, so this guarantee will always hold + // Note this function is not public, and can only be called within the optimize method, so this guarantee will always hold let index_wl = gate .linear_combinations .iter() @@ -157,15 +157,15 @@ impl CSatTransformer { // Add this element into the new gate intermediate_gate.linear_combinations.push(wire_term); } else { - // Nomore elements left in the old gate, we could stop the whole function - // We could alternative let it keep going, as it will never reach this branch again since there are nomore elements left - // XXX: Future optimisation - // nomoreleft = true + // No more elements left in the old gate, we could stop the whole function + // We could alternative let it keep going, as it will never reach this branch again since there are no more elements left + // XXX: Future optimization + // no_more_left = true } } // Constraint this intermediate_gate to be equal to the temp variable by adding it into the IndexMap // We need a unique name for our intermediate variable - // XXX: Another optimisation, which could be applied in another algorithm + // XXX: Another optimization, which could be applied in another algorithm // If two gates have a large fan-in/out and they share a few common terms, then we should create intermediate variables for them // Do some sort of subset matching algorithm for this on the terms of the polynomial let inter_var = Witness(intermediate_variables.len() as u32 + num_witness); @@ -197,16 +197,16 @@ impl CSatTransformer { new_gate } - // A partial gate scan optimisation aim to create intermediate variables in order to compress the polynomial + // A partial gate scan optimization aim to create intermediate variables in order to compress the polynomial // So that it fits within the given width - // Note that this gate follows the full gate scan optimisation. + // Note that this gate follows the full gate scan optimization. // We define the partial width as equal to the full width - 2. // This is because two of our variables cannot be used as they are linked to the multiplication terms // Example: qM1 * wL1 * wR2 + qL1 * wL3 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC // One thing to note is that the multiplication wires do not match any of the fan-in/out wires. This is guaranteed as we have - // just completed the full gate optimisation algorithm. + // just completed the full gate optimization algorithm. // - //Actually we can optimise in two ways here: We can create an intermediate variable which is equal to the fan-in terms + //Actually we can optimize in two ways here: We can create an intermediate variable which is equal to the fan-in terms // t = qL1 * wL3 + qR1 * wR4 -> width = 3 // This `t` value can only use width - 1 terms // The gate now looks like: qM1 * wL1 * wR2 + t + qR2 * wR5+ qO1 * wO5 + qC @@ -229,12 +229,12 @@ impl CSatTransformer { // The gate now looks like: t2 + qR1 * wR4+ qR2 * wR5 + qO1 * wO5 + qC // t3 = t2 + qR1 * wR4 // The gate now looks like: t3 + qR2 * wR5 + qO1 * wO5 + qC - // This took the same amount of gates, but which one is better when the width increases? Compute this and maybe do both optimisations + // This took the same amount of gates, but which one is better when the width increases? Compute this and maybe do both optimizations // naming : partial_gate_mul_first_opt and partial_gate_fan_first_opt // Also remember that since we did full gate scan, there is no way we can have a non-zero mul term along with the wL and wR terms being non-zero // // Cases, a lot of mul terms, a lot of fan-in terms, 50/50 - fn partial_gate_scan_optimisation( + fn partial_gate_scan_optimization( &self, mut gate: Expression, intermediate_variables: &mut IndexMap, @@ -243,7 +243,7 @@ impl CSatTransformer { // We will go for the easiest route, which is to convert all multiplications into additions using intermediate variables // Then use intermediate variables again to squash the fan-in, so that it can fit into the appropriate width - // First check if this polynomial actually needs a partial gate optimisation + // First check if this polynomial actually needs a partial gate optimization // There is the chance that it fits perfectly within the arithmetic gate if gate.fits_in_one_identity(self.width) { return gate; @@ -316,7 +316,7 @@ impl CSatTransformer { // keep consistency with the original equation. gate.linear_combinations.extend(added); - self.partial_gate_scan_optimisation(gate, intermediate_variables, num_witness) + self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness) } } @@ -343,9 +343,9 @@ fn simple_reduction_smoke_test() { let num_witness = 4; - let optimiser = CSatTransformer::new(3); - let got_optimised_gate_a = - optimiser.transform(gate_a, &mut intermediate_variables, num_witness); + let optimizer = CSatTransformer::new(3); + let got_optimized_gate_a = + optimizer.transform(gate_a, &mut intermediate_variables, num_witness); // a = b + c + d => a - b - c - d = 0 // For width3, the result becomes: @@ -354,7 +354,7 @@ fn simple_reduction_smoke_test() { // // a - b + e = 0 let e = Witness(4); - let expected_optimised_gate_a = Expression { + let expected_optimized_gate_a = Expression { mul_terms: vec![], linear_combinations: vec![ (FieldElement::one(), a), @@ -363,7 +363,7 @@ fn simple_reduction_smoke_test() { ], q_c: FieldElement::zero(), }; - assert_eq!(expected_optimised_gate_a, got_optimised_gate_a); + assert_eq!(expected_optimized_gate_a, got_optimized_gate_a); assert_eq!(intermediate_variables.len(), 1); diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index 9b87040be..e82fa9d4b 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -4,9 +4,9 @@ pub mod compiler; pub mod pwg; -use crate::{compiler::compile, pwg::arithmetic::ArithmeticSolver}; +use crate::pwg::arithmetic::ArithmeticSolver; use acir::{ - circuit::{directives::Directive, opcodes::BlackBoxFuncCall, Circuit, Opcode, PublicInputs}, + circuit::{directives::Directive, opcodes::BlackBoxFuncCall, Circuit, Opcode}, native_types::{Expression, Witness}, BlackBoxFunc, }; @@ -67,7 +67,7 @@ pub trait PartialWitnessGenerator { let resolution = match &opcode { Opcode::Arithmetic(expr) => ArithmeticSolver::solve(initial_witness, expr), Opcode::BlackBoxFuncCall(bb_func) => { - Self::solve_blackbox_function_call(initial_witness, bb_func) + Self::solve_black_box_function_call(initial_witness, bb_func) } Opcode::Directive(directive) => Self::solve_directives(initial_witness, directive), }; @@ -88,7 +88,7 @@ pub trait PartialWitnessGenerator { self.solve(initial_witness, unsolved_opcodes) } - fn solve_blackbox_function_call( + fn solve_black_box_function_call( initial_witness: &mut BTreeMap, func_call: &BlackBoxFuncCall, ) -> Result<(), OpcodeResolutionError>; @@ -139,11 +139,11 @@ pub trait ProofSystemCompiler { /// as this in most cases will be inefficient. For this reason, we want to throw a hard error /// if the language and proof system does not line up. fn np_language(&self) -> Language; - // Returns true if the backend supports the selected blackbox function - fn blackbox_function_supported(&self, opcode: &BlackBoxFunc) -> bool; + // Returns true if the backend supports the selected black box function + fn black_box_function_supported(&self, opcode: &BlackBoxFunc) -> bool; /// Creates a Proof given the circuit description and the witness values. - /// It is important to note that the intermediate witnesses for blackbox functions will not generated + /// It is important to note that the intermediate witnesses for black box functions will not generated /// This is the responsibility of the proof system. /// /// See `SmartContract` regarding the removal of `num_witnesses` and `num_public_inputs` @@ -168,6 +168,23 @@ pub trait ProofSystemCompiler { ) -> bool; fn get_exact_circuit_size(&self, circuit: Circuit) -> u32; + + fn preprocess(&self, circuit: Circuit) -> (Vec, Vec); + + fn prove_with_pk( + &self, + circuit: Circuit, + witness_values: BTreeMap, + proving_key: Vec, + ) -> Vec; + + fn verify_with_vk( + &self, + proof: &[u8], + public_inputs: Vec, + circuit: Circuit, + verification_key: Vec, + ) -> bool; } /// Supported NP complete languages @@ -180,7 +197,7 @@ pub enum Language { pub fn hash_constraint_system(cs: &Circuit) -> [u8; 32] { let mut bytes = Vec::new(); - cs.write(&mut bytes).expect("could not serialise circuit"); + cs.write(&mut bytes).expect("could not serialize circuit"); use sha2::{digest::FixedOutput, Digest, Sha256}; let mut hasher = Sha256::new(); @@ -190,31 +207,26 @@ pub fn hash_constraint_system(cs: &Circuit) -> [u8; 32] { } #[deprecated( - note = "For backwards compatibility, this method allows you to derive _sensible_ defaults for blackbox function support based on the np language. \n Backends should simply specify what they support." + note = "For backwards compatibility, this method allows you to derive _sensible_ defaults for black box function support based on the np language. \n Backends should simply specify what they support." )] // This is set to match the previous functionality that we had -// Where we could deduce what blackbox functions were supported +// Where we could deduce what black box functions were supported // by knowing the np complete language -pub fn default_is_blackbox_supported( +pub fn default_is_black_box_supported( language: Language, ) -> compiler::transformers::IsBlackBoxSupported { - // R1CS does not support any of the blackbox functions by default. + // R1CS does not support any of the black box functions by default. // The compiler will replace those that it can -- ie range, xor, and - fn r1cs_is_supported(opcode: &BlackBoxFunc) -> bool { - match opcode { - _ => false, - } + fn r1cs_is_supported(_opcode: &BlackBoxFunc) -> bool { + false } - // PLONK supports most of the blackbox functions by default + // PLONK supports most of the black box functions by default // The ones which are not supported, the acvm compiler will // attempt to transform into supported gates. If these are also not available // then a compiler error will be emitted. fn plonk_is_supported(opcode: &BlackBoxFunc) -> bool { - match opcode { - BlackBoxFunc::AES => false, - _ => true, - } + !matches!(opcode, BlackBoxFunc::AES) } match language { diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index 1be89ef19..a650e7b89 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -11,11 +11,12 @@ use std::collections::BTreeMap; pub mod arithmetic; // Directives pub mod directives; -// blackbox functions +// black box functions pub mod hash; pub mod logic; pub mod range; pub mod signature; +pub mod sorting; // Returns the concrete value for a particular witness // If the witness has no assignment, then diff --git a/acvm/src/pwg/arithmetic.rs b/acvm/src/pwg/arithmetic.rs index 7ac522413..2d611db02 100644 --- a/acvm/src/pwg/arithmetic.rs +++ b/acvm/src/pwg/arithmetic.rs @@ -29,6 +29,7 @@ impl ArithmeticSolver { initial_witness: &mut BTreeMap, gate: &Expression, ) -> Result<(), OpcodeResolutionError> { + let gate = &ArithmeticSolver::evaluate(gate, initial_witness); // Evaluate multiplication term let mul_result = ArithmeticSolver::solve_mul_term(gate, initial_witness); // Evaluate the fan-in terms @@ -124,29 +125,44 @@ impl ArithmeticSolver { witness_assignments: &BTreeMap, ) -> MulTerm { // First note that the mul term can only contain one/zero term - // We are assuming it has been optimised. + // We are assuming it has been optimized. match arith_gate.mul_terms.len() { 0 => MulTerm::Solved(FieldElement::zero()), - 1 => { - let q_m = &arith_gate.mul_terms[0].0; - let w_l = &arith_gate.mul_terms[0].1; - let w_r = &arith_gate.mul_terms[0].2; - - // Check if these values are in the witness assignments - let w_l_value = witness_assignments.get(w_l); - let w_r_value = witness_assignments.get(w_r); - - match (w_l_value, w_r_value) { - (None, None) => MulTerm::TooManyUnknowns, - (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r), - (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l), - (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r), - } - } + 1 => ArithmeticSolver::solve_mul_term_helper( + &arith_gate.mul_terms[0], + witness_assignments, + ), _ => panic!("Mul term in the arithmetic gate must contain either zero or one term"), } } + fn solve_mul_term_helper( + term: &(FieldElement, Witness, Witness), + witness_assignments: &BTreeMap, + ) -> MulTerm { + let (q_m, w_l, w_r) = term; + // Check if these values are in the witness assignments + let w_l_value = witness_assignments.get(w_l); + let w_r_value = witness_assignments.get(w_r); + + match (w_l_value, w_r_value) { + (None, None) => MulTerm::TooManyUnknowns, + (Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r), + (None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l), + (Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r), + } + } + + fn solve_fan_in_term_helper( + term: &(FieldElement, Witness), + witness_assignments: &BTreeMap, + ) -> Option { + let (q_l, w_l) = term; + // Check if we have w_l + let w_l_value = witness_assignments.get(w_l); + w_l_value.map(|a| *q_l * *a) + } + /// Returns the summation of all of the variables, plus the unknown variable /// Returns None, if there is more than one unknown variable /// We cannot assign @@ -163,19 +179,14 @@ impl ArithmeticSolver { let mut result = FieldElement::zero(); for term in arith_gate.linear_combinations.iter() { - let q_l = term.0; - let w_l = &term.1; - - // Check if we have w_l - let w_l_value = witness_assignments.get(w_l); - - match w_l_value { - Some(a) => result += q_l * *a, + let value = ArithmeticSolver::solve_fan_in_term_helper(term, witness_assignments); + match value { + Some(a) => result += a, None => { unknown_variable = *term; num_unknowns += 1; } - }; + } // If we have more than 1 unknown, then we cannot solve this equation if num_unknowns > 1 { @@ -189,6 +200,39 @@ impl ArithmeticSolver { GateStatus::GateSolvable(result, unknown_variable) } + + // Partially evaluate the gate using the known witnesses + pub fn evaluate( + expr: &Expression, + initial_witness: &BTreeMap, + ) -> Expression { + let mut result = Expression::default(); + for &(c, w1, w2) in &expr.mul_terms { + let mul_result = ArithmeticSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness); + match mul_result { + MulTerm::OneUnknown(v, w) => { + if !v.is_zero() { + result.linear_combinations.push((v, w)); + } + } + MulTerm::TooManyUnknowns => { + if !c.is_zero() { + result.mul_terms.push((c, w1, w2)); + } + } + MulTerm::Solved(f) => result.q_c += f, + } + } + for &(c, w) in &expr.linear_combinations { + if let Some(f) = ArithmeticSolver::solve_fan_in_term_helper(&(c, w), initial_witness) { + result.q_c += f; + } else if !c.is_zero() { + result.linear_combinations.push((c, w)); + } + } + result.q_c += expr.q_c; + result + } } #[test] diff --git a/acvm/src/pwg/directives.rs b/acvm/src/pwg/directives.rs index 94bc2d60b..67819efb3 100644 --- a/acvm/src/pwg/directives.rs +++ b/acvm/src/pwg/directives.rs @@ -1,12 +1,16 @@ -use std::collections::BTreeMap; +use std::{cmp::Ordering, collections::BTreeMap}; -use acir::{circuit::directives::Directive, native_types::Witness, FieldElement}; +use acir::{ + circuit::directives::{Directive, LogInfo}, + native_types::Witness, + FieldElement, +}; use num_bigint::BigUint; use num_traits::{One, Zero}; use crate::OpcodeResolutionError; -use super::{get_value, witness_to_value}; +use super::{get_value, sorting::route, witness_to_value}; pub fn solve_directives( initial_witness: &mut BTreeMap, @@ -45,8 +49,16 @@ pub fn solve_directives( (&int_a % &int_b, &int_a / &int_b) }; - initial_witness.insert(*q, FieldElement::from_be_bytes_reduce(&int_q.to_bytes_be())); - initial_witness.insert(*r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be())); + insert_witness( + *q, + FieldElement::from_be_bytes_reduce(&int_q.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *r, + FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), + initial_witness, + )?; Ok(()) } @@ -59,8 +71,16 @@ pub fn solve_directives( let int_b: BigUint = &int_a % &pow; let int_c: BigUint = (&int_a - &int_b) / &pow; - initial_witness.insert(*b, FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be())); - initial_witness.insert(*c, FieldElement::from_be_bytes_reduce(&int_c.to_bytes_be())); + insert_witness( + *b, + FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *c, + FieldElement::from_be_bytes_reduce(&int_c.to_bytes_be()), + initial_witness, + )?; Ok(()) } @@ -78,16 +98,7 @@ pub fn solve_directives( } else { FieldElement::zero() }; - match initial_witness.entry(b[i]) { - std::collections::btree_map::Entry::Vacant(e) => { - e.insert(v); - } - std::collections::btree_map::Entry::Occupied(e) => { - if e.get() != &v { - return Err(OpcodeResolutionError::UnsatisfiedConstrain); - } - } - } + insert_witness(b[i], v, initial_witness)?; } Ok(()) @@ -105,10 +116,115 @@ pub fn solve_directives( let int_r = &int_a - &bb; let int_b = &bb >> (bit_size - 1); - initial_witness.insert(*b, FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be())); - initial_witness.insert(*r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be())); + insert_witness( + *b, + FieldElement::from_be_bytes_reduce(&int_b.to_bytes_be()), + initial_witness, + )?; + insert_witness( + *r, + FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), + initial_witness, + )?; + + Ok(()) + } + Directive::PermutationSort { + inputs: a, + tuple, + bits, + sort_by, + } => { + let mut val_a = Vec::new(); + let mut base = Vec::new(); + for (i, element) in a.iter().enumerate() { + assert_eq!(element.len(), *tuple as usize); + let mut element_val = Vec::with_capacity(*tuple as usize + 1); + for e in element { + element_val.push(get_value(e, initial_witness)?); + } + let field_i = FieldElement::from(i as i128); + element_val.push(field_i); + base.push(field_i); + val_a.push(element_val); + } + val_a.sort_by(|a, b| { + for i in sort_by { + let int_a = BigUint::from_bytes_be(&a[*i as usize].to_be_bytes()); + let int_b = BigUint::from_bytes_be(&b[*i as usize].to_be_bytes()); + let cmp = int_a.cmp(&int_b); + if cmp != Ordering::Equal { + return cmp; + } + } + Ordering::Equal + }); + let b = val_a.iter().map(|a| *a.last().unwrap()).collect(); + let control = route(base, b); + for (w, value) in bits.iter().zip(control) { + let value = if value { + FieldElement::one() + } else { + FieldElement::zero() + }; + insert_witness(*w, value, initial_witness)?; + } + Ok(()) + } + Directive::Log(info) => { + let witnesses = match info { + LogInfo::FinalizedOutput(output_string) => { + println!("{output_string}"); + return Ok(()); + } + LogInfo::WitnessOutput(witnesses) => witnesses, + }; + + if witnesses.len() == 1 { + let witness = &witnesses[0]; + let log_value = witness_to_value(initial_witness, *witness)?; + println!("{}", log_value.to_hex()); + + return Ok(()); + } + + // If multiple witnesses are to be fetched for a log directive, + // it assumed that an array is meant to be printed to standard output + // + // Collect all field element values corresponding to the given witness indices + // and convert them to hex strings. + let mut elements_as_hex = Vec::with_capacity(witnesses.len()); + for witness in witnesses { + let element = witness_to_value(initial_witness, *witness)?; + elements_as_hex.push(element.to_hex()); + } + + // Join all of the hex strings using a comma + let comma_separated_elements = elements_as_hex.join(","); + + let output_witnesses_string = "[".to_owned() + &comma_separated_elements + "]"; + + println!("{output_witnesses_string}"); Ok(()) } } } + +fn insert_witness( + w: Witness, + value: FieldElement, + initial_witness: &mut BTreeMap, +) -> Result<(), OpcodeResolutionError> { + match initial_witness.entry(w) { + std::collections::btree_map::Entry::Vacant(e) => { + e.insert(value); + } + std::collections::btree_map::Entry::Occupied(e) => { + if e.get() != &value { + return Err(OpcodeResolutionError::UnsatisfiedConstrain); + } + } + } + Ok(()) +} diff --git a/acvm/src/pwg/logic.rs b/acvm/src/pwg/logic.rs index 0dff9451f..57579ecf4 100644 --- a/acvm/src/pwg/logic.rs +++ b/acvm/src/pwg/logic.rs @@ -59,10 +59,12 @@ impl LogicSolver { } // TODO: Is there somewhere else that we can put this? // TODO: extraction methods are needed for some opcodes like logic and range -pub(crate) fn extract_input_output(gc: &BlackBoxFuncCall) -> (Witness, Witness, Witness, u32) { - let a = &gc.inputs[0]; - let b = &gc.inputs[1]; - let result = &gc.outputs[0]; +pub(crate) fn extract_input_output( + bb_func_call: &BlackBoxFuncCall, +) -> (Witness, Witness, Witness, u32) { + let a = &bb_func_call.inputs[0]; + let b = &bb_func_call.inputs[1]; + let result = &bb_func_call.outputs[0]; // The num_bits variable should be the same for all witnesses assert_eq!( diff --git a/acvm/src/pwg/sorting.rs b/acvm/src/pwg/sorting.rs new file mode 100644 index 000000000..217d0c0c2 --- /dev/null +++ b/acvm/src/pwg/sorting.rs @@ -0,0 +1,392 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use acir::FieldElement; + +// A sorting network is a graph of connected switches +// It is defined recursively so here we only keep track of the outer layer of switches +struct SortingNetwork { + n: usize, // size of the network + x_inputs: Vec, // inputs of the network + y_inputs: Vec, // outputs of the network + x_values: BTreeMap, // map for matching a y value with a x value + y_values: BTreeMap, // map for matching a x value with a y value + inner_x: Vec, // positions after the switch_x + inner_y: Vec, // positions after the sub-networks, and before the switch_y + switch_x: Vec, // outer switches for the inputs + switch_y: Vec, // outer switches for the outputs + free: BTreeSet, // outer switches available for looping +} + +impl SortingNetwork { + fn new(n: usize) -> SortingNetwork { + let free_len = (n - 1) / 2; + let mut free = BTreeSet::new(); + for i in 0..free_len { + free.insert(i); + } + SortingNetwork { + n, + x_inputs: Vec::with_capacity(n), + y_inputs: Vec::with_capacity(n), + x_values: BTreeMap::new(), + y_values: BTreeMap::new(), + inner_x: Vec::with_capacity(n), + inner_y: Vec::with_capacity(n), + switch_x: Vec::with_capacity(n / 2), + switch_y: Vec::with_capacity(free_len), + free, + } + } + + fn init(&mut self, inputs: Vec, outputs: Vec) { + let n = self.n; + assert_eq!(inputs.len(), outputs.len()); + assert_eq!(inputs.len(), n); + + self.x_inputs = inputs; + self.y_inputs = outputs; + for i in 0..self.n { + self.x_values.insert(self.x_inputs[i], i); + self.y_values.insert(self.y_inputs[i], i); + } + self.switch_x = vec![false; n / 2]; + self.switch_y = vec![false; (n - 1) / 2]; + self.inner_x = vec![FieldElement::zero(); n]; + self.inner_y = vec![FieldElement::zero(); n]; + + //Route the single wires so we do not need to handle this case later on + self.inner_y[n - 1] = self.y_inputs[n - 1]; + if n % 2 == 0 { + self.inner_y[n / 2 - 1] = self.y_inputs[n - 2]; + } else { + self.inner_x[n - 1] = self.x_inputs[n - 1]; + } + } + + //route a wire from outputs to its value in the inputs + fn route_out_wire(&mut self, y: usize, sub: bool) -> usize { + // sub <- y + if self.is_single_y(y) { + assert!(sub); + } else { + let port = y % 2 != 0; + let s1 = sub ^ port; + let inner = self.compute_inner(y, s1); + self.configure_y(y, s1, inner); + } + // x <- sub + let x = self.x_values.remove(&self.y_inputs[y]).unwrap(); + if !self.is_single_x(x) { + let port2 = x % 2 != 0; + let s2 = sub ^ port2; + let inner = self.compute_inner(x, s2); + self.configure_x(x, s2, inner); + } + x + } + + //route a wire from inputs to its value in the outputs + fn route_in_wire(&mut self, x: usize, sub: bool) -> usize { + // x -> sub + assert!(!self.is_single_x(x)); + let port = x % 2 != 0; + let s1 = sub ^ port; + let inner = self.compute_inner(x, s1); + self.configure_x(x, s1, inner); + + // sub -> y + let y = self.y_values.remove(&self.x_inputs[x]).unwrap(); + if !self.is_single_y(y) { + let port = y % 2 != 0; + let s2 = sub ^ port; + let inner = self.compute_inner(y, s2); + self.configure_y(y, s2, inner); + } + y + } + + //update the computed switch and inner values for an input wire + fn configure_x(&mut self, x: usize, switch: bool, inner: usize) { + self.inner_x[inner] = self.x_inputs[x]; + self.switch_x[x / 2] = switch; + } + + //update the computed switch and inner values for an output wire + fn configure_y(&mut self, y: usize, switch: bool, inner: usize) { + self.inner_y[inner] = self.y_inputs[y]; + self.switch_y[y / 2] = switch; + } + + // returns the other wire belonging to the same switch + fn sibling(index: usize) -> usize { + index + 1 - 2 * (index % 2) + } + + // returns a free switch + fn take(&mut self) -> Option { + self.free.first().copied() + } + + fn is_single_x(&self, a: usize) -> bool { + let n = self.x_inputs.len(); + n % 2 == 1 && a == n - 1 + } + + fn is_single_y(&mut self, a: usize) -> bool { + let n = self.x_inputs.len(); + a >= n - 2 + n % 2 + } + + // compute the inner position of idx through its switch + fn compute_inner(&self, idx: usize, switch: bool) -> usize { + if switch ^ (idx % 2 == 1) { + idx / 2 + self.n / 2 + } else { + idx / 2 + } + } + + fn new_start(&mut self) -> (Option, usize) { + let next = self.take(); + if let Some(switch) = next { + (next, 2 * switch) + } else { + (None, 0) + } + } +} + +// Computes the control bits of the sorting network which transform inputs into outputs +// implementation is based on https://www.mdpi.com/2227-7080/10/1/16 +pub fn route(inputs: Vec, outputs: Vec) -> Vec { + assert_eq!(inputs.len(), outputs.len()); + match inputs.len() { + 0 => Vec::new(), + 1 => { + assert_eq!(inputs[0], outputs[0]); + Vec::new() + } + 2 => { + if inputs[0] == outputs[0] { + assert_eq!(inputs[1], outputs[1]); + vec![false] + } else { + assert_eq!(inputs[1], outputs[0]); + assert_eq!(inputs[0], outputs[1]); + vec![true] + } + } + _ => { + let n = inputs.len(); + + let mut result; + let n1 = n / 2; + let in_sub1; + let out_sub1; + let in_sub2; + let out_sub2; + + // process the outer layer in a code block so that the intermediate data is cleared before recursion + { + let mut network = SortingNetwork::new(n); + network.init(inputs, outputs); + + //We start with the last single wire + let mut out_idx = n - 1; + let mut start_sub = true; //it is connected to the lower inner network + let mut switch = None; + let mut start = None; + + while !network.free.is_empty() { + // the processed switch is no more available + if let Some(free_switch) = switch { + network.free.remove(&free_switch); + } + + // connect the output wire to its matching input + let in_idx = network.route_out_wire(out_idx, start_sub); + if network.is_single_x(in_idx) { + start_sub = !start_sub; //We need to restart, but did not complete the loop so we switch the sub network + (start, out_idx) = network.new_start(); + switch = start; + continue; + } + + // loop from the sibling + let next = SortingNetwork::sibling(in_idx); + // connect the input wire to its matching output, using the other sub-network + out_idx = network.route_in_wire(next, !start_sub); + switch = Some(out_idx / 2); + if start == switch || network.is_single_y(out_idx) { + //loop is complete, need a fresh start + (start, out_idx) = network.new_start(); + switch = start; + } else { + // we loop back from the sibling + out_idx = SortingNetwork::sibling(out_idx); + } + } + //All the wires are connected, we can now route the sub-networks + result = network.switch_x; + result.extend(network.switch_y); + in_sub1 = network.inner_x[0..n1].to_vec(); + in_sub2 = network.inner_x[n1..].to_vec(); + out_sub1 = network.inner_y[0..n1].to_vec(); + out_sub2 = network.inner_y[n1..].to_vec(); + } + let s1 = route(in_sub1, out_sub1); + result.extend(s1); + let s2 = route(in_sub2, out_sub2); + result.extend(s2); + result + } + } +} + +#[cfg(test)] +mod test { + use crate::pwg::sorting::route; + use acir::FieldElement; + use rand::prelude::*; + + pub fn execute_network(config: Vec, inputs: Vec) -> Vec { + let n = inputs.len(); + if n == 1 { + return inputs; + } + let mut in1 = Vec::new(); + let mut in2 = Vec::new(); + //layer 1: + for i in 0..n / 2 { + if config[i] { + in1.push(inputs[2 * i + 1]); + in2.push(inputs[2 * i]); + } else { + in1.push(inputs[2 * i]); + in2.push(inputs[2 * i + 1]); + } + } + if n % 2 == 1 { + in2.push(*inputs.last().unwrap()); + } + let n2 = n / 2 + (n - 1) / 2; + let n3 = n2 + switch_nb(n / 2); + let mut result = Vec::new(); + let out1 = execute_network(config[n2..n3].to_vec(), in1); + let out2 = execute_network(config[n3..].to_vec(), in2); + //last layer: + for i in 0..(n - 1) / 2 { + if config[n / 2 + i] { + result.push(out2[i]); + result.push(out1[i]); + } else { + result.push(out1[i]); + result.push(out2[i]); + } + } + if n % 2 == 0 { + result.push(*out1.last().unwrap()); + result.push(*out2.last().unwrap()); + } else { + result.push(*out2.last().unwrap()) + } + result + } + + // returns the number of switches in the network + pub fn switch_nb(n: usize) -> usize { + let mut s = 0; + for i in 0..n { + s += f64::from((i + 1) as u32).log2().ceil() as usize; + } + s + } + + #[test] + fn test_route() { + //basic tests + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, false]); + + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(1_i128), + FieldElement::from(3_i128), + FieldElement::from(2_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, true]); + + let a = vec![ + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(3_i128), + FieldElement::from(2_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![true, true, true]); + + let a = vec![ + FieldElement::from(0_i128), + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + ]; + let b = vec![ + FieldElement::from(2_i128), + FieldElement::from(3_i128), + FieldElement::from(0_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, true, true, true, true]); + + let a = vec![ + FieldElement::from(0_i128), + FieldElement::from(1_i128), + FieldElement::from(2_i128), + FieldElement::from(3_i128), + FieldElement::from(4_i128), + ]; + let b = vec![ + FieldElement::from(0_i128), + FieldElement::from(3_i128), + FieldElement::from(4_i128), + FieldElement::from(2_i128), + FieldElement::from(1_i128), + ]; + let c = route(a, b); + assert_eq!(c, vec![false, false, false, true, false, true, false, true]); + + // random tests + for i in 2..50 { + let mut a = vec![FieldElement::zero()]; + for j in 0..i - 1 { + a.push(a[j] + FieldElement::one()); + } + let mut rng = rand::thread_rng(); + let mut b = a.clone(); + b.shuffle(&mut rng); + let c = route(a.clone(), b.clone()); + assert_eq!(b, execute_network(c, a)); + } + } +} diff --git a/cspell.json b/cspell.json new file mode 100644 index 000000000..8a4bf70c6 --- /dev/null +++ b/cspell.json @@ -0,0 +1,44 @@ +{ + "version": "0.2", + "words": [ + "blackbox", + // In code + // + "acir", + "ACIR", + "ACVM", + "Axyz", + "arithmetization", + "bivariate", + "canonicalize", + "coeff", + "consts", + "csat", + "decomp", + "deflater", + "endianness", + "euclidian", + "hasher", + "Merkle", + "OddRange", + "Pedersen", + "PLONKC", + "prehashed", + "pubkey", + "repr", + "secp", + "Schnorr", + "Shleft", + "Shright", + "stdlib", + "struct", + "TORADIX", + // Dependencies + // + "bufread", + "flate", + "indexmap", + "thiserror", + "typenum" + ] +} diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 08f72b692..5d6ccc1b2 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -1,11 +1,10 @@ [package] name = "acvm_stdlib" -version = "0.3.1" +version = "0.4.1" edition = "2021" license = "MIT" description = "The ACVM standard library." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -acir = { version = "0.3.1", path = "../acir" } -acir_field = { version = "0.3.1", path = "../acir_field", default-features = true } +acir = { version = "0.4.1", path = "../acir", features = ["bn254"] } diff --git a/stdlib/src/fallback.rs b/stdlib/src/fallback.rs index 818c8b874..cc95fb993 100644 --- a/stdlib/src/fallback.rs +++ b/stdlib/src/fallback.rs @@ -1,9 +1,9 @@ use crate::helpers::VariableStore; use acir::{ + acir_field::FieldElement, circuit::{directives::Directive, Opcode}, native_types::{Expression, Witness}, }; -use acir_field::FieldElement; // Perform bit decomposition on the provided expression #[deprecated(note = "use bit_decomposition function instead")] @@ -68,7 +68,7 @@ pub(crate) fn bit_decomposition( bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it. new_gates.push(Opcode::Arithmetic(bit_decomp_constraint)); - (new_gates, bit_vector, variables.finalise()) + (new_gates, bit_vector, variables.finalize()) } // Range constraint @@ -140,7 +140,7 @@ pub fn xor( let two = FieldElement::from(2_i128); // Build an xor expression - // TODO: check this is the correct arithmetisation + // TODO: check this is the correct arithmetization let mut xor_expr = Expression::default(); for (a_bit, b_bit) in a_bits.into_iter().zip(b_bits) { xor_expr.term_addition(two_pow, a_bit); diff --git a/stdlib/src/helpers.rs b/stdlib/src/helpers.rs index 4eea6a518..5ab258368 100644 --- a/stdlib/src/helpers.rs +++ b/stdlib/src/helpers.rs @@ -17,7 +17,7 @@ impl<'a> VariableStore<'a> { witness } - pub fn finalise(self) -> u32 { + pub fn finalize(self) -> u32 { *self.witness_index } }