diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr
index b5d2624d3e7..a3ac0b9e5da 100644
--- a/noir_stdlib/src/hash/sha256.nr
+++ b/noir_stdlib/src/hash/sha256.nr
@@ -12,6 +12,20 @@ global MSG_SIZE_PTR = 56;
 // Size of the message block when packed as 4-byte integer array.
 global INT_BLOCK_SIZE = 16;
 
+// A `u32` integer consists of 4 bytes.
+global INT_SIZE = 4;
+
+// Index of the integer in the `INT_BLOCK` where the length is written.
+global INT_SIZE_PTR = MSG_SIZE_PTR / INT_SIZE;
+
+// Magic numbers for bit shifting.
+// Works with actual bit shifting as well, as the compiler turns them into * and /,
+// but circuit execution appears to be 10% faster this way.
+global TWO_POW_8 = 256;
+global TWO_POW_16 = TWO_POW_8 * 256;
+global TWO_POW_24 = TWO_POW_16 * 256;
+global TWO_POW_32 = TWO_POW_24 as u64 * 256;
+
 // Index of a byte in a 64 byte block; ie. 0..=63
 type BLOCK_BYTE_PTR = u32;
 
@@ -19,8 +33,8 @@ type BLOCK_BYTE_PTR = u32;
 type INT_BLOCK = [u32; INT_BLOCK_SIZE];
 
 // A message block is a slice of the original message of a fixed size,
-// potentially padded with zeroes.
-type MSG_BLOCK = [u8; BLOCK_SIZE];
+// potentially padded with zeros, with neighbouring 4 bytes packed into integers.
+type MSG_BLOCK = INT_BLOCK;
 
 // The hash is 32 bytes.
 type HASH = [u8; 32];
@@ -50,16 +64,19 @@ pub fn digest<let N: u32>(msg: [u8; N]) -> HASH {
 pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> HASH {
     let message_size = message_size as u32;
     let num_blocks = N / BLOCK_SIZE;
-    let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
+    let mut msg_block: MSG_BLOCK = [0; INT_BLOCK_SIZE];
+    // Intermediate hash, starting with the canonical initial value
     let mut h: STATE = [
         1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635,
         1541459225,
-    ]; // Intermediate hash, starting with the canonical initial value
-    let mut msg_byte_ptr = 0; // Pointer into msg_block
+    ];
+    // Pointer into msg_block on a 64 byte scale
+    let mut msg_byte_ptr = 0;
 
     for i in 0..num_blocks {
         let msg_start = BLOCK_SIZE * i;
         let (new_msg_block, new_msg_byte_ptr) =
             unsafe { build_msg_block(msg, message_size, msg_start) };
+
         if msg_start < message_size {
             msg_block = new_msg_block;
         }
@@ -77,7 +94,7 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> HASH {
         // If the block is filled, compress it.
         // An un-filled block is handled after this loop.
         if (msg_start < message_size) & (msg_byte_ptr == BLOCK_SIZE) {
-            h = sha256_compression(msg_u8_to_u32(msg_block), h);
+            h = sha256_compression(msg_block, h);
         }
     }
 
@@ -114,49 +131,39 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> HASH {
     // Pad the rest such that we have a [u32; 2] block at the end representing the length
     // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
     // Here we rely on the fact that everything beyond the available input is set to 0.
-    msg_block[msg_byte_ptr] = 1 << 7;
-    let last_block = msg_block;
+    msg_block = update_block_item(
+        msg_block,
+        msg_byte_ptr,
+        |msg_item| set_item_byte_then_zeros(msg_item, msg_byte_ptr, 1 << 7),
+    );
     msg_byte_ptr = msg_byte_ptr + 1;
+    let last_block = msg_block;
 
     // If we don't have room to write the size, compress the block and reset it.
     if msg_byte_ptr > MSG_SIZE_PTR {
-        h = sha256_compression(msg_u8_to_u32(msg_block), h);
+        h = sha256_compression(msg_block, h);
         // `attach_len_to_msg_block` will zero out everything after the `msg_byte_ptr`.
         msg_byte_ptr = 0;
     }
 
     msg_block = unsafe { attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) };
 
-    if !crate::runtime::is_unconstrained() {
+    if !is_unconstrained() {
         verify_msg_len(msg_block, last_block, msg_byte_ptr, message_size);
     }
 
     hash_final_block(msg_block, h)
 }
 
-// Convert 64-byte array to array of 16 u32s
-fn msg_u8_to_u32(msg: MSG_BLOCK) -> INT_BLOCK {
-    let mut msg32: INT_BLOCK = [0; INT_BLOCK_SIZE];
-
-    for i in 0..INT_BLOCK_SIZE {
-        let mut msg_field: Field = 0;
-        for j in 0..4 {
-            msg_field = msg_field * 256 + msg[64 - 4 * (i + 1) + j] as Field;
-        }
-        msg32[15 - i] = msg_field as u32;
-    }
-
-    msg32
-}
-
 // Take `BLOCK_SIZE` number of bytes from `msg` starting at `msg_start`.
-// Returns the block and the length that has been copied rather than padded with zeroes.
+// Returns the block and the length that has been copied rather than padded with zeros.
 unconstrained fn build_msg_block<let N: u32>(
     msg: [u8; N],
     message_size: u32,
     msg_start: u32,
 ) -> (MSG_BLOCK, BLOCK_BYTE_PTR) {
-    let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
+    let mut msg_block: MSG_BLOCK = [0; INT_BLOCK_SIZE];
+
     // We insert `BLOCK_SIZE` bytes (or up to the end of the message)
     let block_input = if msg_start + BLOCK_SIZE > message_size {
         if message_size < msg_start {
@@ -169,29 +176,73 @@ unconstrained fn build_msg_block<let N: u32>(
     } else {
         BLOCK_SIZE
     };
-    for k in 0..block_input {
-        msg_block[k] = msg[msg_start + k];
+
+    // Figure out the number of items in the int array that we have to pack.
+    // e.g. if the input is [0,1,2,3,4,5] then we need to pack it as 2 items: [0123, 4500]
+    let mut int_input = block_input / INT_SIZE;
+    if block_input % INT_SIZE != 0 {
+        int_input = int_input + 1;
+    };
+
+    for i in 0..int_input {
+        let mut msg_item: u32 = 0;
+        // Always construct the integer as 4 bytes, even if it means going beyond the input.
+        for j in 0..INT_SIZE {
+            let k = i * INT_SIZE + j;
+            let msg_byte = if k < block_input {
+                msg[msg_start + k]
+            } else {
+                0
+            };
+            msg_item = lshift8(msg_item, 1) + msg_byte as u32;
+        }
+        msg_block[i] = msg_item;
     }
+
+    // Returning the index as if it was a 64 byte array.
+    // We have to project it down to 16 items and use bit shifting to get a byte back if we need it.
     (msg_block, block_input)
 }
 
 // Verify the block we are compressing was appropriately constructed by `build_msg_block`
 // and matches the input data. Returns the index of the first unset item.
+// If `message_size` is less than `msg_start` then this is called with the old non-empty block;
+// in that case we can skip verification, ie. no need to check that everything is zero.
 fn verify_msg_block<let N: u32>(
     msg: [u8; N],
     message_size: u32,
     msg_block: MSG_BLOCK,
     msg_start: u32,
 ) -> BLOCK_BYTE_PTR {
-    let mut msg_byte_ptr: u32 = 0; // Message byte pointer
+    let mut msg_byte_ptr = 0;
     let mut msg_end = msg_start + BLOCK_SIZE;
     if msg_end > N {
         msg_end = N;
     }
+    // We might have to go beyond the input to pad the fields.
+    if msg_end % INT_SIZE != 0 {
+        msg_end = msg_end + INT_SIZE - msg_end % INT_SIZE;
+    }
 
-    for k in msg_start..msg_end {
-        if k < message_size {
-            assert_eq(msg_block[msg_byte_ptr], msg[k]);
+    // Reconstructed packed item.
+    let mut msg_item: u32 = 0;
+
+    // Inclusive at the end so that we can compare the last item.
+    let mut i: u32 = 0;
+    for k in msg_start..=msg_end {
+        if k % INT_SIZE == 0 {
+            // If we consumed some input we can compare against the block.
+            if (msg_start < message_size) & (k > msg_start) {
+                assert_eq(msg_block[i], msg_item as u32);
+                i = i + 1;
+                msg_item = 0;
+            }
+        }
+        // Shift the accumulator
+        msg_item = lshift8(msg_item, 1);
+        // If we have input to consume, add it at the rightmost position.
+        if k < message_size & k < msg_end {
+            msg_item = msg_item + msg[k] as u32;
             msg_byte_ptr = msg_byte_ptr + 1;
         }
     }
@@ -199,69 +250,261 @@ fn verify_msg_block<let N: u32>(
     msg_byte_ptr
 }
 
-// Verify the block we are compressing was appropriately padded with zeroes by `build_msg_block`.
+// Verify the block we are compressing was appropriately padded with zeros by `build_msg_block`.
 // This is only relevant for the last, potentially partially filled block.
 fn verify_msg_block_padding(msg_block: MSG_BLOCK, msg_byte_ptr: BLOCK_BYTE_PTR) {
+    // Check all the way to the end of the block.
+    verify_msg_block_zeros(msg_block, msg_byte_ptr, INT_BLOCK_SIZE);
+}
+
+// Verify that a region of ints in the message block is (partially) zeroed,
+// up to an (exclusive) maximum which can either be the end of the block
+// or just where the size is to be written.
+fn verify_msg_block_zeros(
+    msg_block: MSG_BLOCK,
+    mut msg_byte_ptr: BLOCK_BYTE_PTR,
+    max_int_byte_ptr: u32,
+) {
     // This variable is used to get around the compiler under-constrained check giving a warning.
     // We want to check against a constant zero, but if it does not come from the circuit inputs
     // or return values the compiler check will issue a warning.
     let zero = msg_block[0] - msg_block[0];
 
-    for i in 0..BLOCK_SIZE {
-        if i >= msg_byte_ptr {
+    // First integer which is supposed to be (partially) zero.
+    let mut int_byte_ptr = msg_byte_ptr / INT_SIZE;
+
+    // Check partial zeros.
+    let modulo = msg_byte_ptr % INT_SIZE;
+    if modulo != 0 {
+        let zeros = INT_SIZE - modulo;
+        let mask = if zeros == 3 {
+            TWO_POW_24
+        } else if zeros == 2 {
+            TWO_POW_16
+        } else {
+            TWO_POW_8
+        };
+        assert_eq(msg_block[int_byte_ptr] % mask, zero);
+        int_byte_ptr = int_byte_ptr + 1;
+    }
+
+    // Check the rest of the items.
+    for i in 0..max_int_byte_ptr {
+        if i >= int_byte_ptr {
             assert_eq(msg_block[i], zero);
         }
     }
 }
 
+// Verify that up to the byte pointer the two blocks are equal.
+// At the byte pointer the new block can be partially zeroed.
+fn verify_msg_block_equals_last(
+    msg_block: MSG_BLOCK,
+    last_block: MSG_BLOCK,
+    mut msg_byte_ptr: BLOCK_BYTE_PTR,
+) {
+    // msg_byte_ptr is the position at which the two no longer have to be the same.
+    // The first integer which is supposed to be (partially) zero contains that pointer.
+    let mut int_byte_ptr = msg_byte_ptr / INT_SIZE;
+
+    // Check partial zeros.
+    let modulo = msg_byte_ptr % INT_SIZE;
+    if modulo != 0 {
+        // Reconstruct the partially zero item from the last block.
+        let last_field = last_block[int_byte_ptr];
+        let mut msg_item: u32 = 0;
+        // Reset to where they are still equal.
+        msg_byte_ptr = msg_byte_ptr - modulo;
+        for i in 0..INT_SIZE {
+            msg_item = lshift8(msg_item, 1);
+            if i < modulo {
+                msg_item = msg_item + get_item_byte(last_field, msg_byte_ptr) as u32;
+                msg_byte_ptr = msg_byte_ptr + 1;
+            }
+        }
+        assert_eq(msg_block[int_byte_ptr], msg_item);
+    }
+
+    for i in 0..INT_SIZE_PTR {
+        if i < int_byte_ptr {
+            assert_eq(msg_block[i], last_block[i]);
+        }
+    }
+}
+
+// Apply a function on the block item which the pointer indicates.
+fn update_block_item<Env>(
+    mut msg_block: MSG_BLOCK,
+    msg_byte_ptr: BLOCK_BYTE_PTR,
+    f: fn[Env](u32) -> u32,
+) -> MSG_BLOCK {
+    let i = msg_byte_ptr / INT_SIZE;
+    msg_block[i] = f(msg_block[i]);
+    msg_block
+}
+
+// Set the rightmost `zeros` number of bytes to 0.
+fn set_item_zeros(item: u32, zeros: u8) -> u32 {
+    lshift8(rshift8(item, zeros), zeros)
+}
+
+// Replace one byte in the item with a value, and set everything after it to zero.
+fn set_item_byte_then_zeros(msg_item: u32, msg_byte_ptr: BLOCK_BYTE_PTR, msg_byte: u8) -> u32 {
+    let zeros = INT_SIZE - msg_byte_ptr % INT_SIZE;
+    let zeroed_item = set_item_zeros(msg_item, zeros as u8);
+    let new_item = byte_into_item(msg_byte, msg_byte_ptr);
+    zeroed_item + new_item
+}
+
+// Get a byte of a message item according to its overall position in the `BLOCK_SIZE` space.
+fn get_item_byte(mut msg_item: u32, msg_byte_ptr: BLOCK_BYTE_PTR) -> u8 {
+    // How many times do we have to shift to the right to get to the position we want?
+    let max_shifts = INT_SIZE - 1;
+    let shifts = max_shifts - msg_byte_ptr % INT_SIZE;
+    msg_item = rshift8(msg_item, shifts as u8);
+    // At this point the byte we want is in the rightmost position.
+    msg_item as u8
+}
+
+// Project a byte into a position in a field based on the overall block pointer.
+// For example putting 1 into pointer 5 would be 100, because overall we would
+// have [____, 0100] with indexes [0123,4567].
+fn byte_into_item(msg_byte: u8, msg_byte_ptr: BLOCK_BYTE_PTR) -> u32 {
+    let mut msg_item = msg_byte as u32;
+    // How many times do we have to shift to the left to get to the position we want?
+    let max_shifts = INT_SIZE - 1;
+    let shifts = max_shifts - msg_byte_ptr % INT_SIZE;
+    lshift8(msg_item, shifts as u8)
+}
+
+// Construct a field out of 4 bytes.
+fn make_item(b0: u8, b1: u8, b2: u8, b3: u8) -> u32 {
+    let mut item = b0 as u32;
+    item = lshift8(item, 1) + b1 as u32;
+    item = lshift8(item, 1) + b2 as u32;
+    item = lshift8(item, 1) + b3 as u32;
+    item
+}
+
+// Shift by 8 bits to the left between 0 and 4 times.
+// Checks `is_unconstrained()` to just use a bitshift if we're running in an unconstrained context,
+// otherwise multiplies by 256.
+fn lshift8(item: u32, shifts: u8) -> u32 {
+    if is_unconstrained() {
+        if item == 0 {
+            0
+        } else {
+            // Brillig wouldn't shift 0<<4 without overflow.
+            item << (8 * shifts)
+        }
+    } else {
+        // We can do a for loop up to INT_SIZE or an if-else.
+        if shifts == 0 {
+            item
+        } else if shifts == 1 {
+            item * TWO_POW_8
+        } else if shifts == 2 {
+            item * TWO_POW_16
+        } else if shifts == 3 {
+            item * TWO_POW_24
+        } else {
+            // Doesn't make sense, but it's most likely called on 0 anyway.
+            0
+        }
+    }
+}
+
+// Shift by 8 bits to the right between 0 and 4 times.
+// Checks `is_unconstrained()` to just use a bitshift if we're running in an unconstrained context,
+// otherwise divides by 256.
+fn rshift8(item: u32, shifts: u8) -> u32 {
+    if is_unconstrained() {
+        item >> (8 * shifts)
+    } else {
+        // Division wouldn't work on `Field`.
+        if shifts == 0 {
+            item
+        } else if shifts == 1 {
+            item / TWO_POW_8
+        } else if shifts == 2 {
+            item / TWO_POW_16
+        } else if shifts == 3 {
+            item / TWO_POW_24
+        } else {
+            0
+        }
+    }
+}
+
 // Zero out all bytes between the end of the message and where the length is appended,
 // then write the length into the last 8 bytes of the block.
 unconstrained fn attach_len_to_msg_block(
     mut msg_block: MSG_BLOCK,
-    msg_byte_ptr: BLOCK_BYTE_PTR,
+    mut msg_byte_ptr: BLOCK_BYTE_PTR,
     message_size: u32,
 ) -> MSG_BLOCK {
     // We assume that `msg_byte_ptr` is less than 57 because if not then it is reset to zero before calling this function.
-    // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
-    for i in msg_byte_ptr..MSG_SIZE_PTR {
+    // In any case, fill blocks up with zeros until the last 64 bits (i.e. until msg_byte_ptr = 56).
+    // There can be one item which has to be partially zeroed.
+    let modulo = msg_byte_ptr % INT_SIZE;
+    if modulo != 0 {
+        // Index of the block in which we find the item we need to partially zero.
+        let i = msg_byte_ptr / INT_SIZE;
+        let zeros = INT_SIZE - modulo;
+        msg_block[i] = set_item_zeros(msg_block[i], zeros as u8);
+        msg_byte_ptr = msg_byte_ptr + zeros;
+    }
+
+    // The rest can be zeroed without bit shifting anything.
+    for i in (msg_byte_ptr / INT_SIZE)..INT_SIZE_PTR {
         msg_block[i] = 0;
     }
 
+    // Set the last two 4 byte ints as the first/second half of the 8 bytes of the length.
     let len = 8 * message_size;
     let len_bytes: [u8; 8] = (len as Field).to_be_bytes();
-    for i in 0..8 {
-        msg_block[MSG_SIZE_PTR + i] = len_bytes[i];
+    for i in 0..=1 {
+        let shift = i * 4;
+        msg_block[INT_SIZE_PTR + i] = make_item(
+            len_bytes[shift],
+            len_bytes[shift + 1],
+            len_bytes[shift + 2],
+            len_bytes[shift + 3],
+        );
     }
 
     msg_block
 }
 
-// Verify that the message length was correctly written by `attach_len_to_msg_block`.
+// Verify that the message length was correctly written by `attach_len_to_msg_block`,
+// and that everything between the byte pointer and the size pointer was zeroed,
+// and that everything before the byte pointer was untouched.
 fn verify_msg_len(
     msg_block: MSG_BLOCK,
     last_block: MSG_BLOCK,
     msg_byte_ptr: BLOCK_BYTE_PTR,
     message_size: u32,
 ) {
-    for i in 0..MSG_SIZE_PTR {
-        let predicate = (i < msg_byte_ptr) as u8;
-        let expected_byte = predicate * last_block[i];
-        assert_eq(msg_block[i], expected_byte);
-    }
+    // Check zeros up to the size pointer.
+    verify_msg_block_zeros(msg_block, msg_byte_ptr, INT_SIZE_PTR);
+
+    // Check that up to the pointer we match the last block.
+    verify_msg_block_equals_last(msg_block, last_block, msg_byte_ptr);
 
     // We verify the message length was inserted correctly by reversing the byte decomposition.
-    let len = 8 * message_size;
-    let mut reconstructed_len: Field = 0;
-    for i in MSG_SIZE_PTR..BLOCK_SIZE {
-        reconstructed_len = 256 * reconstructed_len + msg_block[i] as Field;
+    let mut reconstructed_len: u64 = 0;
+    for i in INT_SIZE_PTR..INT_BLOCK_SIZE {
+        reconstructed_len = reconstructed_len * TWO_POW_32;
+        reconstructed_len = reconstructed_len + msg_block[i] as u64;
     }
-    assert_eq(reconstructed_len, len as Field);
+    let len = 8 * message_size as u64;
+    assert_eq(reconstructed_len, len);
 }
 
 // Perform the final compression, then transform the `STATE` into `HASH`.
 fn hash_final_block(msg_block: MSG_BLOCK, mut state: STATE) -> HASH {
     let mut out_h: HASH = [0; 32]; // Digest as sequence of bytes
 
     // Hash final padded block
-    state = sha256_compression(msg_u8_to_u32(msg_block), state);
+    state = sha256_compression(msg_block, state);
 
     // Return final hash as byte array
     for j in 0..8 {
@@ -275,6 +518,11 @@ fn hash_final_block(msg_block: MSG_BLOCK, mut state: STATE) -> HASH {
 }
 
 mod tests {
+    use super::{
+        attach_len_to_msg_block, build_msg_block, byte_into_item, get_item_byte, lshift8, make_item,
+        set_item_byte_then_zeros, set_item_zeros,
+    };
+    use super::{INT_BLOCK, INT_BLOCK_SIZE, MSG_BLOCK};
     use super::sha256_var;
 
     #[test]
@@ -510,4 +758,79 @@ mod tests {
         assert_eq(var_length_hash_575, fixed_length_hash);
         assert_eq(var_length_hash_576, fixed_length_hash);
     }
+
+    #[test]
+    fn test_get_item_byte() {
+        let fld = make_item(10, 20, 30, 40);
+        assert_eq(fld, 0x0a141e28);
+        assert_eq(get_item_byte(fld, 0), 10);
+        assert_eq(get_item_byte(fld, 4), 10);
+        assert_eq(get_item_byte(fld, 6), 30);
+    }
+
+    #[test]
+    fn test_byte_into_item() {
+        let fld = make_item(0, 20, 0, 0);
+        assert_eq(byte_into_item(20, 1), fld);
+        assert_eq(byte_into_item(20, 5), fld);
+    }
+
+    #[test]
+    fn test_set_item_zeros() {
+        let fld0 = make_item(10, 20, 30, 40);
+        let fld1 = make_item(10, 0, 0, 0);
+        assert_eq(set_item_zeros(fld0, 3), fld1);
+        assert_eq(set_item_zeros(fld0, 4), 0);
+        assert_eq(set_item_zeros(0, 4), 0);
+    }
+
+    #[test]
+    fn test_set_item_byte_then_zeros() {
+        let fld0 = make_item(10, 20, 30, 40);
+        let fld1 = make_item(10, 50, 0, 0);
+        assert_eq(set_item_byte_then_zeros(fld0, 1, 50), fld1);
+    }
+
+    #[test]
+    fn test_build_msg_block_start_0() {
+        let input = [
+            102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117,
+            101, 115, 46, 48,
+        ];
+        assert_eq(input.len(), 22);
+        let (msg_block, msg_byte_ptr) = unsafe { build_msg_block(input, input.len(), 0) };
+        assert_eq(msg_byte_ptr, input.len());
+        assert_eq(msg_block[0], make_item(input[0], input[1], input[2], input[3]));
+        assert_eq(msg_block[1], make_item(input[4], input[5], input[6], input[7]));
+        assert_eq(msg_block[5], make_item(input[20], input[21], 0, 0));
+        assert_eq(msg_block[6], 0);
+    }
+
+    #[test]
+    fn test_build_msg_block_start_1() {
+        let input = [
+            102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117,
+            101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99,
+            111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112,
+            108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116,
+        ];
+        assert_eq(input.len(), 68);
+        let (msg_block, msg_byte_ptr) = unsafe { build_msg_block(input, input.len(), 64) };
+        assert_eq(msg_byte_ptr, 4);
+        assert_eq(msg_block[0], make_item(input[64], input[65], input[66], input[67]));
+        assert_eq(msg_block[1], 0);
+    }
+
+    #[test]
+    fn test_attach_len_to_msg_block() {
+        let input: INT_BLOCK = [
+            2152555847, 1397309779, 1936618851, 1262052426, 1936876331, 1985297723, 543702374,
+            1919905082, 1131376244, 1701737517, 1417244773, 978151789, 1697470053, 1920166255,
+            1849316213, 1651139939,
+        ];
+        let msg_block = unsafe { attach_len_to_msg_block(input, 1, 448) };
+        assert_eq(msg_block[0], ((1 << 7) as u32) * 256 * 256 * 256);
+        assert_eq(msg_block[1], 0);
+        assert_eq(msg_block[15], 3584);
+    }
 }
diff --git a/tooling/nargo_cli/tests/stdlib-tests.rs b/tooling/nargo_cli/tests/stdlib-tests.rs
index d27a2961a62..46a241b7e0b 100644
--- a/tooling/nargo_cli/tests/stdlib-tests.rs
+++ b/tooling/nargo_cli/tests/stdlib-tests.rs
@@ -76,7 +76,7 @@ fn run_stdlib_tests() {
         &bn254_blackbox_solver::Bn254BlackBoxSolver,
         &mut context,
         &test_function,
-        false,
+        true,
         None,
         Some(dummy_package.root_dir.clone()),
        Some(dummy_package.name.to_string()),
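
Note (reviewer illustration, not part of the patch): the standalone Noir sketch below shows the big-endian 4-byte packing convention that the new `MSG_BLOCK = INT_BLOCK` representation relies on, and the multiply-by-256 trick behind the `TWO_POW_8`/`TWO_POW_16`/`TWO_POW_24` globals and `lshift8`. The helper `pack_be` is hypothetical; it mirrors what `make_item` in the patch computes but is not taken from it.

// Pack four bytes into a u32, most significant byte first -- the same layout
// `make_item` produces. Multiplying by powers of 256 is the constrained-friendly
// equivalent of shifting left by 8, 16 and 24 bits.
fn pack_be(b0: u8, b1: u8, b2: u8, b3: u8) -> u32 {
    (b0 as u32) * 256 * 256 * 256 + (b1 as u32) * 256 * 256 + (b2 as u32) * 256 + (b3 as u32)
}

#[test]
fn test_pack_be_layout() {
    // Matches the `make_item(10, 20, 30, 40) == 0x0a141e28` expectation in the patch's tests.
    assert_eq(pack_be(10, 20, 30, 40), 0x0a141e28);
    // The SHA-256 padding byte 0x80 written at byte offset 0 of an item lands in the most
    // significant byte, which is why `test_attach_len_to_msg_block` expects
    // `((1 << 7) as u32) * 256 * 256 * 256` in `msg_block[0]`.
    assert_eq(pack_be(1 << 7, 0, 0, 0), 0x80000000);
}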