Skip to content

Commit

Permalink
fix: Allow macros to change types on each iteration of a comptime loop (
Browse files Browse the repository at this point in the history
noir-lang/noir#6105)

chore: Schnorr signature verification in Noir (noir-lang/noir#5437)
feat: Implement solver for mov_registers_to_registers (noir-lang/noir#6089)
  • Loading branch information
AztecBot committed Sep 19, 2024
2 parents 34399d0 + ed91acb commit c50fb94
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .noir-sync-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4170c55019bd27fd51be8a46637514dfe86de53c
0864e7c945089cc06f8cc9e5c7d933c465d8c892
18 changes: 17 additions & 1 deletion noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ impl<'context> Elaborator<'context> {
0
}

pub fn unify(
pub(super) fn unify(
&mut self,
actual: &Type,
expected: &Type,
Expand All @@ -644,6 +644,22 @@ impl<'context> Elaborator<'context> {
}
}

/// Do not apply type bindings even after a successful unification.
/// This function is used by the interpreter for some comptime code
/// which can change types e.g. on each iteration of a for loop.
pub fn unify_without_applying_bindings(
&mut self,
actual: &Type,
expected: &Type,
file: fm::FileId,
make_error: impl FnOnce() -> TypeCheckError,
) {
let mut bindings = TypeBindings::new();
if actual.try_unify(expected, &mut bindings).is_err() {
self.errors.push((make_error().into(), file));
}
}

/// Wrapper of Type::unify_with_coercions using self.errors
pub(super) fn unify_with_coercions(
&mut self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1303,9 +1303,11 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
// Macro calls are typed as type variables during type checking.
// Now that we know the type we need to further unify it in case there
// are inconsistencies or the type needs to be known.
// We don't commit any type bindings made this way in case the type of
// the macro result changes across loop iterations.
let expected_type = self.elaborator.interner.id_type(id);
let actual_type = result.get_type();
self.unify(&actual_type, &expected_type, location);
self.unify_without_binding(&actual_type, &expected_type, location);
}
Ok(result)
}
Expand All @@ -1319,16 +1321,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
}
}

fn unify(&mut self, actual: &Type, expected: &Type, location: Location) {
// We need to swap out the elaborator's file since we may be
// in a different one currently, and it uses that for the error location.
let old_file = std::mem::replace(&mut self.elaborator.file, location.file);
self.elaborator.unify(actual, expected, || TypeCheckError::TypeMismatch {
expected_typ: expected.to_string(),
expr_typ: actual.to_string(),
expr_span: location.span,
fn unify_without_binding(&mut self, actual: &Type, expected: &Type, location: Location) {
self.elaborator.unify_without_applying_bindings(actual, expected, location.file, || {
TypeCheckError::TypeMismatch {
expected_typ: expected.to_string(),
expr_typ: actual.to_string(),
expr_span: location.span,
}
});
self.elaborator.file = old_file;
}

fn evaluate_method_call(
Expand Down
28 changes: 26 additions & 2 deletions noir/noir-repo/compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ fn numeric_generics_type_kind_mismatch() {
}
global M: u16 = 3;
fn main() {
let _ = bar::<M>();
}
Expand All @@ -1972,7 +1972,7 @@ fn numeric_generics_value_kind_mismatch_u32_u64() {
}
impl<T, let MaxLen: u32> BoundedVec<T, MaxLen> {
pub fn extend_from_bounded_vec<let Len: u32>(&mut self, _vec: BoundedVec<T, Len>) {
pub fn extend_from_bounded_vec<let Len: u32>(&mut self, _vec: BoundedVec<T, Len>) {
// We do this to avoid an unused variable warning on `self`
let _ = self.len;
for _ in 0..Len { }
Expand Down Expand Up @@ -3722,5 +3722,29 @@ fn use_numeric_generic_in_trait_method() {
"#;

let errors = get_program_errors(src);
println!("{errors:?}");
assert_eq!(errors.len(), 0);
}

#[test]
fn macro_result_type_mismatch() {
let src = r#"
fn main() {
comptime {
let x = unquote!(quote { "test" });
let _: Field = x;
}
}
comptime fn unquote(q: Quoted) -> Quoted {
q
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::TypeMismatch { .. })
));
}
15 changes: 15 additions & 0 deletions noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,21 @@ impl EmbeddedCurveScalar {
let (a,b) = crate::field::bn254::decompose(scalar);
EmbeddedCurveScalar { lo: a, hi: b }
}

//Bytes to scalar: take the first (after the specified offset) 16 bytes of the input as the lo value, and the next 16 bytes as the hi value
#[field(bn254)]
fn from_bytes(bytes: [u8; 64], offset: u32) -> EmbeddedCurveScalar {
let mut v = 1;
let mut lo = 0 as Field;
let mut hi = 0 as Field;
for i in 0..16 {
lo = lo + (bytes[offset+31 - i] as Field) * v;
hi = hi + (bytes[offset+15 - i] as Field) * v;
v = v * 256;
}
let sig_s = crate::embedded_curve_ops::EmbeddedCurveScalar { lo, hi };
sig_s
}
}

impl Eq for EmbeddedCurveScalar {
Expand Down
65 changes: 65 additions & 0 deletions noir/noir-repo/noir_stdlib/src/schnorr.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use crate::collections::vec::Vec;
use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar};

#[foreign(schnorr_verify)]
// docs:start:schnorr_verify
pub fn verify_signature<let N: u32>(
Expand All @@ -20,3 +23,65 @@ pub fn verify_signature_slice(
// docs:end:schnorr_verify_slice
{}

pub fn verify_signature_noir<let N: u32>(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) -> bool {
//scalar lo/hi from bytes
let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0);
let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32);
// pub_key is on Grumpkin curve
let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17)
& (!public_key.is_infinite);

if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) {
let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message);

is_ok = !r_is_infinite;
for i in 0..32 {
is_ok &= result[i] == signature[32 + i];
}
}
is_ok
}

pub fn assert_valid_signature<let N: u32>(public_key: EmbeddedCurvePoint, signature: [u8; 64], message: [u8; N]) {
//scalar lo/hi from bytes
let sig_s = EmbeddedCurveScalar::from_bytes(signature, 0);
let sig_e = EmbeddedCurveScalar::from_bytes(signature, 32);

// assert pub_key is on Grumpkin curve
assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17);
assert(public_key.is_infinite == false);
// assert signature is not null
assert((sig_s.lo != 0) | (sig_s.hi != 0));
assert((sig_e.lo != 0) | (sig_e.hi != 0));

let (r_is_infinite, result) = calculate_signature_challenge(public_key, sig_s, sig_e, message);

assert(!r_is_infinite);
for i in 0..32 {
assert(result[i] == signature[32 + i]);
}
}

fn calculate_signature_challenge<let N: u32>(
public_key: EmbeddedCurvePoint,
sig_s: EmbeddedCurveScalar,
sig_e: EmbeddedCurveScalar,
message: [u8; N]
) -> (bool, [u8; 32]) {
let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let r = crate::embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]);
// compare the _hashes_ rather than field elements modulo r
let pedersen_hash = crate::hash::pedersen_hash([r.x, public_key.x, public_key.y]);
let pde: [u8; 32] = pedersen_hash.to_be_bytes();

let mut hash_input = [0; N + 32];
for i in 0..32 {
hash_input[i] = pde[i];
}
for i in 0..N {
hash_input[32+i] = message[i];
}

let result = crate::hash::blake2s(hash_input);
(r.is_infinite, result)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "macro_result_type"
type = "bin"
authors = [""]
compiler_version = ">=0.34.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn main() {
comptime
{
let signature = "hello".as_ctstring();
let string = signature.as_quoted_str!();
let result = half(string);
assert_eq(result, 2);
}
}

comptime fn half<let N: u32>(_s: str<N>) -> u32 {
N / 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "comptime_change_type_each_iteration"
type = "bin"
authors = [""]
compiler_version = ">=0.34.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
fn main() {
comptime
{
for i in 9..11 {
// Lengths are different on each iteration:
// foo9, foo10
let name = f"foo{i}".as_ctstring().as_quoted_str!();

// So to call `from_signature` we need to delay the type check
// by quoting the function call so that we re-typecheck on each iteration
let hash = std::meta::unquote!(quote { from_signature($name) });
assert(hash > 3);
}
}
}

fn from_signature<let N: u32>(_signature: str<N>) -> u32 {
N
}
104 changes: 2 additions & 102 deletions noir/noir-repo/test_programs/execution_success/schnorr/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ fn main(
// Regression for issue #2421
// We want to make sure that we can accurately verify a signature whose message is a slice vs. an array
let message_field_bytes: [u8; 10] = message_field.to_be_bytes();
let mut message2 = [0; 42];
for i in 0..10 {
assert(message[i] == message_field_bytes[i]);
message2[i] = message[i];
}

// Is there ever a situation where someone would want
// to ensure that a signature was invalid?
Expand All @@ -27,102 +22,7 @@ fn main(
let valid_signature = std::schnorr::verify_signature(pub_key_x, pub_key_y, signature, message);
assert(valid_signature);
let pub_key = embedded_curve_ops::EmbeddedCurvePoint { x: pub_key_x, y: pub_key_y, is_infinite: false };
let valid_signature = verify_signature_noir(pub_key, signature, message2);
let valid_signature = std::schnorr::verify_signature_noir(pub_key, signature, message);
assert(valid_signature);
assert_valid_signature(pub_key, signature, message2);
}

// TODO: to put in the stdlib once we have numeric generics
// Meanwhile, you have to use a message with 32 additional bytes:
// If you want to verify a signature on a message of 10 bytes, you need to pass a message of length 42,
// where the first 10 bytes are the one from the original message (the other bytes are not used)
pub fn verify_signature_noir<let M: u32>(
public_key: embedded_curve_ops::EmbeddedCurvePoint,
signature: [u8; 64],
message: [u8; M]
) -> bool {
let N = message.len() - 32;

//scalar lo/hi from bytes
let sig_s = bytes_to_scalar(signature, 0);
let sig_e = bytes_to_scalar(signature, 32);
// pub_key is on Grumpkin curve
let mut is_ok = (public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17)
& (!public_key.is_infinite);

if ((sig_s.lo != 0) | (sig_s.hi != 0)) & ((sig_e.lo != 0) | (sig_e.hi != 0)) {
let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]);
// compare the _hashes_ rather than field elements modulo r
let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]);
let mut hash_input = [0; M];
let pde: [u8; 32] = pedersen_hash.to_be_bytes();

for i in 0..32 {
hash_input[i] = pde[i];
}
for i in 0..N {
hash_input[32+i] = message[i];
}
let result = std::hash::blake2s(hash_input);

is_ok = !r.is_infinite;
for i in 0..32 {
if result[i] != signature[32 + i] {
is_ok = false;
}
}
}
is_ok
}

pub fn bytes_to_scalar(bytes: [u8; 64], offset: u32) -> embedded_curve_ops::EmbeddedCurveScalar {
let mut v = 1;
let mut lo = 0 as Field;
let mut hi = 0 as Field;
for i in 0..16 {
lo = lo + (bytes[offset+31 - i] as Field) * v;
hi = hi + (bytes[offset+15 - i] as Field) * v;
v = v * 256;
}
let sig_s = embedded_curve_ops::EmbeddedCurveScalar { lo, hi };
sig_s
}

pub fn assert_valid_signature<let M: u32>(
public_key: embedded_curve_ops::EmbeddedCurvePoint,
signature: [u8; 64],
message: [u8; M]
) {
let N = message.len() - 32;
//scalar lo/hi from bytes
let sig_s = bytes_to_scalar(signature, 0);
let sig_e = bytes_to_scalar(signature, 32);

// assert pub_key is on Grumpkin curve
assert(public_key.y * public_key.y == public_key.x * public_key.x * public_key.x - 17);
assert(public_key.is_infinite == false);
// assert signature is not null
assert((sig_s.lo != 0) | (sig_s.hi != 0));
assert((sig_e.lo != 0) | (sig_e.hi != 0));

let g1 = embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let r = embedded_curve_ops::multi_scalar_mul([g1, public_key], [sig_s, sig_e]);
// compare the _hashes_ rather than field elements modulo r
let pedersen_hash = std::hash::pedersen_hash([r.x, public_key.x, public_key.y]);
let mut hash_input = [0; M];
let pde: [u8; 32] = pedersen_hash.to_be_bytes();

for i in 0..32 {
hash_input[i] = pde[i];
}
for i in 0..N {
hash_input[32+i] = message[i];
}
let result = std::hash::blake2s(hash_input);

assert(!r.is_infinite);
for i in 0..32 {
assert(result[i] == signature[32 + i]);
}
std::schnorr::assert_valid_signature(pub_key, signature, message);
}

0 comments on commit c50fb94

Please sign in to comment.