diff --git a/triton-vm/src/aet.rs b/triton-vm/src/aet.rs index d4142ad38..b94219b5b 100644 --- a/triton-vm/src/aet.rs +++ b/triton-vm/src/aet.rs @@ -26,6 +26,7 @@ use crate::stark::StarkHasher; use crate::table::hash_table::HashTable; use crate::table::hash_table::PermutationTrace; use crate::table::op_stack_table::OpStackTableEntry; +use crate::table::ram_table::RamTableCall; use crate::table::table_column::HashBaseTableColumn::CI; use crate::table::table_column::MasterBaseTableColumn; use crate::table::u32_table::U32TableEntry; @@ -55,6 +56,8 @@ pub struct AlgebraicExecutionTrace { pub op_stack_underflow_trace: Array2, + pub ram_trace: Array2, + /// The trace of hashing the program whose execution generated this `AlgebraicExecutionTrace`. /// The resulting digest /// 1. ties a [`Proof`](crate::proof::Proof) to the program it was produced from, and @@ -90,6 +93,7 @@ impl AlgebraicExecutionTrace { instruction_multiplicities: vec![0_u32; program_len], processor_trace: Array2::default([0, processor_table::BASE_WIDTH]), op_stack_underflow_trace: Array2::default([0, op_stack_table::BASE_WIDTH]), + ram_trace: Array2::default([0, ram_table::BASE_WIDTH]), program_hash_trace: Array2::default([0, hash_table::BASE_WIDTH]), hash_trace: Array2::default([0, hash_table::BASE_WIDTH]), sponge_trace: Array2::default([0, hash_table::BASE_WIDTH]), @@ -106,6 +110,7 @@ impl AlgebraicExecutionTrace { self.program_table_length(), self.processor_table_length(), self.op_stack_table_length(), + self.ram_table_length(), self.hash_table_length(), self.cascade_table_length(), self.lookup_table_length(), @@ -185,6 +190,10 @@ impl AlgebraicExecutionTrace { self.op_stack_underflow_trace.nrows() } + pub fn ram_table_length(&self) -> usize { + self.ram_trace.nrows() + } + pub fn hash_table_length(&self) -> usize { self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows() } @@ -232,6 +241,7 @@ impl AlgebraicExecutionTrace { Tip5Trace(instruction, trace) => self.append_sponge_trace(instruction, *trace), U32Call(u32_entry) => self.record_u32_table_entry(u32_entry), OpStackCall(op_stack_entry) => self.record_op_stack_entry(op_stack_entry), + RamCall(ram_call) => self.record_ram_call(ram_call), } } @@ -322,6 +332,12 @@ impl AlgebraicExecutionTrace { .push_row(op_stack_table_row.view()) .unwrap(); } + + fn record_ram_call(&mut self, ram_call: RamTableCall) { + self.ram_trace + .append(Axis(0), ram_call.to_table_rows().view()) + .unwrap(); + } } #[cfg(test)] diff --git a/triton-vm/src/example_programs.rs b/triton-vm/src/example_programs.rs index 54dcebf05..59837ad26 100644 --- a/triton-vm/src/example_programs.rs +++ b/triton-vm/src/example_programs.rs @@ -115,7 +115,7 @@ fn merkle_tree_authentication_path_verify() -> Program { read_io 1 // number of authentication paths to test // stack: [num] mt_ap_verify: // proper program starts here - push 0 swap 1 write_mem pop 1 // store number of APs at RAM address 0 + push 0 write_mem 1 pop 1 // store number of APs at RAM address 0 // stack: [] read_io 5 // read Merkle root // stack: [r4 r3 r2 r1 r0] @@ -128,12 +128,14 @@ fn merkle_tree_authentication_path_verify() -> Program { // stack before: [* r4 r3 r2 r1 r0] // stack after: [* r4 r3 r2 r1 r0] check_aps: - push 0 read_mem dup 0 // get number of APs left to check - // stack: [* r4 r3 r2 r1 r0 0 num_left num_left] + push 1 read_mem 1 pop 1 dup 0 // get number of APs left to check + // stack: [* r4 r3 r2 r1 r0 num_left num_left] push 0 eq // see if there are authentication paths left // stack: [* r4 r3 r2 r1 r0 0 num_left num_left==0] skiz return // return if no authentication paths left - push -1 add write_mem pop 1 // decrease number of authentication paths left to check + push -1 add // decrease number of authentication paths left to check + // stack: [* r4 r3 r2 r1 r0 num_left-1] + push 0 write_mem 1 pop 1 // write decreased number to address 0 // stack: [* r4 r3 r2 r1 r0] call get_idx_and_leaf // stack: [* r4 r3 r2 r1 r0 idx l4 l3 l2 l1 l0] @@ -173,75 +175,87 @@ fn merkle_tree_authentication_path_verify() -> Program { } fn verify_sudoku() -> Program { + // RAM layout: + // 0..=8: primes for mapping digits 1..=9 + // 9: flag for whether the Sudoku is valid + // 10..=90: the Sudoku grid + // + // 10 11 12 13 14 15 16 17 18 + // 19 20 21 22 23 24 25 26 27 + // 28 29 30 31 32 33 34 35 36 + // + // 37 38 39 40 41 42 43 44 45 + // 46 47 48 49 50 51 52 53 54 + // 55 56 57 58 59 60 61 62 63 + // + // 64 65 66 67 68 69 70 71 72 + // 73 74 75 76 77 78 79 80 81 + // 82 83 84 85 86 87 88 89 90 + triton_program!( + call initialize_flag call initialize_primes call read_sudoku - call initialize_flag call write_sudoku_and_check_rows call check_columns call check_squares - push 0 - read_mem - assert - halt + call assert_flag + + // For checking whether the Sudoku is valid. Initially `true`, set to `false` if any + // inconsistency is found. + initialize_flag: + push 1 // _ 1 + push 9 // _ 1 9 + write_mem 1 // _ 10 + pop 1 // _ + return + + invalidate_flag: + push 0 // _ 0 + push 9 // _ 0 9 + write_mem 1 // _ 10 + pop 1 // _ + return + + assert_flag: + push 10 // _ 10 + read_mem 1 // _ flag 9 + pop 1 // _ flag + assert // _ + halt // For mapping legal Sudoku digits to distinct primes. Helps with checking consistency of // rows, columns, and boxes. initialize_primes: - push 1 push 2 write_mem - push 2 push 3 write_mem - push 3 push 5 write_mem - push 4 push 7 write_mem - push 5 push 11 write_mem - push 6 push 13 write_mem - push 7 push 17 write_mem - push 8 push 19 write_mem - push 9 push 23 write_mem - pop 5 pop 4 + push 23 push 19 push 17 + push 13 push 11 push 7 + push 5 push 3 push 2 + push 0 write_mem 5 write_mem 4 + pop 1 return read_sudoku: - call read9 - call read9 - call read9 - call read9 - call read9 - call read9 - call read9 - call read9 - call read9 + call read9 call read9 call read9 + call read9 call read9 call read9 + call read9 call read9 call read9 return read9: - call read1 - call read1 - call read1 - call read1 - call read1 - call read1 - call read1 - call read1 - call read1 + call read1 call read1 call read1 + call read1 call read1 call read1 + call read1 call read1 call read1 return // Applies the mapping from legal Sudoku digits to distinct primes. read1: // _ read_io 1 // _ d - read_mem // _ d p - swap 1 // _ p d + read_mem 1 // _ p d-1 pop 1 // _ p return - initialize_flag: - push 0 - push 1 - write_mem - pop 1 - return - write_sudoku_and_check_rows: // row0 row1 row2 row3 row4 row5 row6 row7 row8 - push 9 // row0 row1 row2 row3 row4 row5 row6 row7 row8 9 - call write_and_check_one_row // row0 row1 row2 row3 row4 row5 row6 row7 18 + push 10 // row0 row1 row2 row3 row4 row5 row6 row7 row8 10 + call write_and_check_one_row // row0 row1 row2 row3 row4 row5 row6 row7 19 call write_and_check_one_row // row0 row1 row2 row3 row4 row5 row6 27 call write_and_check_one_row // row0 row1 row2 row3 row4 row5 36 call write_and_check_one_row // row0 row1 row2 row3 row4 45 @@ -253,159 +267,60 @@ fn verify_sudoku() -> Program { pop 1 // ⊥ return - write_and_check_one_row: // s0 s1 s2 s3 s4 s5 s6 s7 s8 mem_addr - push 1 // s0 s1 s2 s3 s4 s5 s6 s7 s8 mem_addr 1 - call multiply_and_write // s0 s1 s2 s3 s4 s5 s6 s7 (mem_addr+1) s8 - call multiply_and_write // s0 s1 s2 s3 s4 s5 s6 (mem_addr+2) (s8·s7) - call multiply_and_write // s0 s1 s2 s3 s4 s5 (mem_addr+3) (s8·s7·s6) - call multiply_and_write // s0 s1 s2 s3 s4 (mem_addr+4) (s8·s7·s6·s5) - call multiply_and_write // s0 s1 s2 s3 (mem_addr+5) (s8·s7·s6·s5·s4) - call multiply_and_write // s0 s1 s2 (mem_addr+6) (s8·s7·s6·s5·s4·s3) - call multiply_and_write // s0 s1 (mem_addr+7) (s8·s7·s6·s5·s4·s3·s2) - call multiply_and_write // s0 (mem_addr+8) (s8·s7·s6·s5·s4·s3·s2·s1) - call multiply_and_write // (mem_addr+9) (s8·s7·s6·s5·s4·s3·s2·s1·s0) - push 223092870 // (mem_addr+9) (s8·s7·s6·s5·s4·s3·s2·s1·s0) 223092870 - eq // (mem_addr+9) (s8·s7·s6·s5·s4·s3·s2·s1·s0==223092870) - skiz // (mem_addr+9) - return - push 0 // (mem_addr+9) 0 - push 0 // (mem_addr+9) 0 0 - write_mem // (mem_addr+9) 0 - pop 1 // (mem_addr+9) - return - - multiply_and_write: // s mem_addr acc - dup 2 // s mem_addr acc s - mul // s mem_addr (acc·s) - swap 1 // s (acc·s) mem_addr - push 1 // s (acc·s) mem_addr 1 - add // s (acc·s) (mem_addr+1) - swap 1 // s (mem_addr+1) (acc·s) - swap 2 // (acc·s) (mem_addr+1) s - write_mem // (acc·s) (mem_addr+1) - swap 1 // (mem_addr+1) (acc·s) + write_and_check_one_row: // row addr + dup 9 dup 9 dup 9 + dup 9 dup 9 dup 9 + dup 9 dup 9 dup 9 // row addr row + call check_9_numbers // row addr + write_mem 5 write_mem 4 // addr+9 return check_columns: - push 1 - call check_one_column - push 2 - call check_one_column - push 3 - call check_one_column - push 4 - call check_one_column - push 5 - call check_one_column - push 6 - call check_one_column - push 7 - call check_one_column - push 8 - call check_one_column - push 9 - call check_one_column + push 83 call check_one_column + push 84 call check_one_column + push 85 call check_one_column + push 86 call check_one_column + push 87 call check_one_column + push 88 call check_one_column + push 89 call check_one_column + push 90 call check_one_column + push 91 call check_one_column return check_one_column: - call get_column_element - call get_column_element - call get_column_element - call get_column_element - call get_column_element - call get_column_element - call get_column_element - call get_column_element - call get_column_element - pop 1 + read_mem 1 push -8 add read_mem 1 push -8 add read_mem 1 push -8 add + read_mem 1 push -8 add read_mem 1 push -8 add read_mem 1 push -8 add + read_mem 1 push -8 add read_mem 1 push -8 add read_mem 1 pop 1 call check_9_numbers return - get_column_element: - push 9 - add - read_mem - swap 1 - return - check_squares: - push 10 - call check_one_square - push 13 - call check_one_square - push 16 - call check_one_square - push 37 - call check_one_square - push 40 - call check_one_square - push 43 - call check_one_square - push 64 - call check_one_square - push 67 - call check_one_square - push 70 - call check_one_square + push 31 call check_one_square + push 34 call check_one_square + push 37 call check_one_square + push 58 call check_one_square + push 61 call check_one_square + push 64 call check_one_square + push 85 call check_one_square + push 88 call check_one_square + push 91 call check_one_square return check_one_square: - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - push 7 - add - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - push 7 - add - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - push 1 - add - read_mem - swap 1 - pop 1 + read_mem 3 push -6 add + read_mem 3 push -6 add + read_mem 3 pop 1 call check_9_numbers return check_9_numbers: - mul - mul - mul - mul - mul - mul - mul - mul + mul mul mul + mul mul mul + mul mul // 223092870 = 2·3·5·7·11·13·17·19·23 - push 223092870 - eq - skiz - return - push 0 - push 0 - write_mem - pop 1 + push 223092870 eq + skiz return + call invalidate_flag return ) } @@ -413,518 +328,390 @@ fn verify_sudoku() -> Program { pub(crate) fn calculate_new_mmr_peaks_from_append_with_safe_lists() -> Program { triton_program!( // Stack and memory setup - push 0 - push 3 - push 1 - push 457470286889025784 - push 4071246825597671119 + push 0 // _ 0 + push 3 // _ 0 3 + push 1 // _ 0 3 1 + + push 00457470286889025784 + push 04071246825597671119 push 17834064596403781463 push 17484910066710486708 - push 6700794775299091393 - push 6 - push 02628975953172153832 - write_mem - push 10 - push 01807330184488272967 - write_mem - push 12 + push 06700794775299091393 // _ 0 3 1 [digest] + push 06595477061838874830 - write_mem - push 1 - push 2 - write_mem - push 11 push 10897391716490043893 - write_mem - push 7 + push 01807330184488272967 + push 05415221245149797169 + push 05057320540678713304 // _ 0 3 1 [digest] [digest] + push 01838589939278841373 - write_mem - push 8 - push 05057320540678713304 - write_mem - push 4 - push 00880730500905369322 - write_mem - push 5 + push 02628975953172153832 push 06845409670928290394 - write_mem - push 3 - push 04594396536654736100 - write_mem - push 2 - push 64 - write_mem - push 9 - push 05415221245149797169 - write_mem - push 0 - push 323 - write_mem - pop 5 pop 5 pop 3 + push 00880730500905369322 + push 04594396536654736100 // _ 0 3 1 [digest] [digest] [digest] - // Call the main function, followed by `halt` - call tasm_mmr_calculate_new_peaks_from_append_safe - halt + push 64 // _ 0 3 1 [digest] [digest] [digest] 64 + push 2 // _ 0 3 1 [digest] [digest] [digest] 64 2 + push 323 // _ 0 3 1 [digest] [digest] [digest] 64 2 323 - // Main function declaration - // BEFORE: _ old_leaf_count_hi old_leaf_count_lo *peaks [digests (new_leaf)] - // AFTER: _ *new_peaks *auth_path - tasm_mmr_calculate_new_peaks_from_append_safe: - dup 5 dup 5 dup 5 dup 5 dup 5 dup 5 - call tasm_list_safe_u32_push_digest - pop 5 - // stack: _ old_leaf_count_hi old_leaf_count_lo *peaks + push 0 // _ 0 3 1 [digest] [digest] [digest] 64 2 323 0 + write_mem 3 // _ 0 3 1 [digest] [digest] [digest] 3 + write_mem 5 // _ 0 3 1 [digest] [digest] 8 + write_mem 5 // _ 0 3 1 [digest] 13 + pop 1 // _ 0 3 1 [digest] - // Create auth_path return value (vector living in RAM) - push 64 // All MMR auth paths have capacity for 64 digests - call tasm_list_safe_u32_new_digest + call tasm_mmr_calculate_new_peaks_from_append_safe + halt - swap 1 - // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks + // Main function + // BEFORE: _ old_leaf_count_hi old_leaf_count_lo *peaks [digest] + // AFTER: _ *new_peaks *auth_path + tasm_mmr_calculate_new_peaks_from_append_safe: + dup 5 dup 5 dup 5 dup 5 dup 5 dup 5 + call tasm_list_safe_u32_push_digest + pop 5 // _ old_leaf_count_hi old_leaf_count_lo *peaks - dup 3 dup 3 - // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks old_leaf_count_hi old_leaf_count_lo + // Create auth_path return value (vector living in RAM) + // All MMR auth paths have capacity for 64 digests + push 64 // _ old_leaf_count_hi old_leaf_count_lo *peaks 64 + call tasm_list_safe_u32_new_digest - call tasm_arithmetic_u64_incr - call tasm_arithmetic_u64_index_of_last_nonzero_bit + swap 1 + // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks - call tasm_mmr_calculate_new_peaks_from_append_safe_while - // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks (rll = 0) + dup 3 dup 3 + // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks old_leaf_count_hi old_leaf_count_lo - pop 1 - swap 3 pop 1 swap 1 pop 1 - // stack: _ *peaks *auth_path + call tasm_arithmetic_u64_incr + call tasm_arithmetic_u64_index_of_last_nonzero_bit - return + call tasm_mmr_calculate_new_peaks_from_append_safe_while + // stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks (rll = 0) - // Stack start and end: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll - tasm_mmr_calculate_new_peaks_from_append_safe_while: - dup 0 - push 0 - eq - skiz - return - // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll - - swap 2 swap 1 - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks - - dup 0 - dup 0 - call tasm_list_safe_u32_pop_digest - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] - - dup 5 - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] *peaks - - call tasm_list_safe_u32_pop_digest - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] - - // Update authentication path with latest previous_peak - dup 12 - dup 5 dup 5 dup 5 dup 5 dup 5 - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] *auth_path [digests (previous_peak)] - - call tasm_list_safe_u32_push_digest - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] - - hash - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digests (new_peak)] - - call tasm_list_safe_u32_push_digest - // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks - - swap 1 swap 2 - // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll - - push -1 - add - // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks (rll - 1) - - recurse - - - // Before: _ value_hi value_lo - // After: _ (value + 1)_hi (value + 1)_lo - tasm_arithmetic_u64_incr_carry: - pop 1 - push 1 - add - dup 0 - push 4294967296 - eq - push 0 - eq - assert - push 0 - return + pop 1 + swap 3 pop 1 swap 1 pop 1 + // stack: _ *peaks *auth_path + + return - tasm_arithmetic_u64_incr: - push 1 - add - dup 0 - push 4294967296 - eq - skiz - call tasm_arithmetic_u64_incr_carry + // Stack start and end: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll + tasm_mmr_calculate_new_peaks_from_append_safe_while: + dup 0 + push 0 + eq + skiz return + // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll - // Before: _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0} - // After: _ - tasm_list_safe_u32_push_digest: - dup 5 - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, *list - - read_mem - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, *list, length - - // Verify that length < capacity (before increasing length by 1) - swap 1 - push 1 - add - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, length, (*list + 1) - - read_mem - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, length, (*list + 1), capacity - - dup 2 lt - // dup 2 eq - // push 0 eq - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, length, (*list + 1), capacity > length - - assert - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, length, (*list + 1) - - swap 1 - - push 5 - mul - - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, (*list + 1), length * elem_size - - add - push 1 - add - // stack : _ *list, elem{N - 1}, elem{N - 2}, ..., elem{0}, (*list + length * elem_size + 2) -- top of stack is where we will store elements - - swap 1 - write_mem - push 1 - add - swap 1 - write_mem - push 1 - add - swap 1 - write_mem - push 1 - add - swap 1 - write_mem - push 1 - add - swap 1 - write_mem - - // stack : _ *list, address - - pop 1 - // stack : _ *list - - // Increase length indicator by one - read_mem - // stack : _ *list, length - - push 1 - add - // stack : _ *list, length + 1 - - write_mem - // stack : _ *list - - pop 1 - // stack : _ + swap 2 swap 1 + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks - return + dup 0 + dup 0 + call tasm_list_safe_u32_pop_digest + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] - tasm_list_safe_u32_new_digest: - // _ capacity + dup 5 + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] *peaks - // Convert capacity in number of elements to number of VM words required for that list - dup 0 - push 5 - mul + call tasm_list_safe_u32_pop_digest + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] - // _ capacity (capacity_in_bfes) + // Update authentication path with latest previous_peak + dup 12 + dup 5 dup 5 dup 5 dup 5 dup 5 + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] *auth_path [digests (previous_peak)] - push 2 - add - // _ capacity (words to allocate) + call tasm_list_safe_u32_push_digest + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digest (new_hash)] [digests (previous_peak)] - call tasm_memory_dyn_malloc - // _ capacity *list + hash + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks *peaks [digests (new_peak)] - // Write initial length = 0 to `*list` - push 0 - write_mem - // _ capacity *list + call tasm_list_safe_u32_push_digest + // Stack: _ old_leaf_count_hi old_leaf_count_lo rll *auth_path *peaks - // Write capactiy to memory location `*list + 1` - push 1 - add - // _ capacity (*list + 1) + swap 1 swap 2 + // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks rll - swap 1 - write_mem - // _ (*list + 1) capacity + push -1 + add + // Stack: _ old_leaf_count_hi old_leaf_count_lo *auth_path *peaks (rll - 1) - push -1 - add - // _ *list + recurse - return - tasm_arithmetic_u64_decr: - push -1 - add - dup 0 - push -1 - eq - skiz - call tasm_arithmetic_u64_decr_carry - return + // Before: _ value_hi value_lo + // After: _ (value + 1)_hi (value + 1)_lo + tasm_arithmetic_u64_incr_carry: + pop 1 + push 1 + add + dup 0 + push 4294967296 + eq + push 0 + eq + assert + push 0 + return - tasm_arithmetic_u64_decr_carry: - pop 1 - push -1 - add - dup 0 - push -1 - eq - push 0 - eq - assert - push 4294967295 - return + tasm_arithmetic_u64_incr: + push 1 + add + dup 0 + push 4294967296 + eq + skiz + call tasm_arithmetic_u64_incr_carry + return - // BEFORE: _ *list list_length - // AFTER: _ *list - tasm_list_safe_u32_set_length_digest: - // Verify that new length does not exceed capacity - dup 0 - dup 2 - push 1 - add - read_mem - // Stack: *list list_length list_length (*list + 1) capacity + // Before: _ *list, elem[4], elem[3], elem[2], elem[1], elem[0] + // After: _ + tasm_list_safe_u32_push_digest: + dup 5 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list + push 2 add // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list+2 + read_mem 2 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] capacity len *list + + // Verify that length < capacity + swap 2 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list len capacity + dup 1 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list len capacity len + lt // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list len capacity>len + assert // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list len + + // Adjust ram pointer + push 5 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list len 5 + mul // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list 5·len + add // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list+5·len + push 2 // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list+5·len 2 + add // _ *list elem[4] elem[3] elem[2] elem[1] elem[0] *list+5·len+2 + + // Write all elements + write_mem 5 // _ *list *list+5·len+7 + + // Remove ram pointer + pop 1 // _ *list + + // Increase length indicator by one + push 1 add // _ *list+1 + read_mem 1 // _ len *list + swap 1 // _ *list len + push 1 // _ *list len 1 + add // _ *list len+1 + swap 1 // _ len+1 *list + write_mem 1 // _ *list+1 + pop 1 // _ + return - swap 1 - pop 1 - // Stack: *list list_length list_length capacity + // BEFORE: _ capacity + // AFTER: + tasm_list_safe_u32_new_digest: + // Convert capacity in number of elements to number of VM words required for that list + dup 0 // _ capacity capacity + push 5 // _ capacity capacity 5 + mul // _ capacity 5·capacity + // _ capacity capacity_in_bfes + push 2 // _ capacity capacity_in_bfes 2 + add // _ capacity capacity_in_bfes+2 + // _ capacity words_to_allocate + + call tasm_memory_dyn_malloc // _ capacity *list + + // Write initial length = 0 to `*list`, and capacity to `*list + 1` + push 0 // _ capacity *list 0 + swap 1 // _ capacity 0 *list + write_mem 2 // _ (*list+2) + push -2 // _ (*list+2) -2 + add // _ *list + return - lt - push 0 - eq - // Stack: *list list_length list_length <= capacity + tasm_arithmetic_u64_decr: + push -1 + add + dup 0 + push -1 + eq + skiz + call tasm_arithmetic_u64_decr_carry + return - assert - // Stack: *list list_length + tasm_arithmetic_u64_decr_carry: + pop 1 + push -1 + add + dup 0 + push -1 + eq + push 0 + eq + assert + push 4294967295 + return - write_mem - // Stack: *list + // BEFORE: _ value_hi value_lo + // AFTER: _ log2_floor(value) + tasm_arithmetic_u64_log_2_floor: + swap 1 + push 1 + dup 1 + // stack: _ value_lo value_hi 1 value_hi - return + skiz call tasm_arithmetic_u64_log_2_floor_then + skiz call tasm_arithmetic_u64_log_2_floor_else + // stack: _ log2_floor(value) - // BEFORE: _ value_hi value_lo - // AFTER: _ log2_floor(value) - tasm_arithmetic_u64_log_2_floor: - swap 1 - push 1 - dup 1 - // stack: _ value_lo value_hi 1 value_hi + return - skiz call tasm_arithmetic_u64_log_2_floor_then - skiz call tasm_arithmetic_u64_log_2_floor_else - // stack: _ log2_floor(value) + tasm_arithmetic_u64_log_2_floor_then: + // value_hi != 0 + // stack: _ value_lo value_hi 1 + swap 1 + swap 2 + pop 2 + // stack: _ value_hi - return + log_2_floor + push 32 + add + // stack: _ (log2_floor(value_hi) + 32) - tasm_arithmetic_u64_log_2_floor_then: - // value_hi != 0 - // stack: _ value_lo value_hi 1 - swap 1 - swap 2 - pop 2 - // stack: _ value_hi + push 0 + // stack: _ (log2_floor(value_hi) + 32) 0 - log_2_floor - push 32 - add - // stack: _ (log2_floor(value_hi) + 32) + return - push 0 - // stack: _ (log2_floor(value_hi) + 32) 0 + tasm_arithmetic_u64_log_2_floor_else: + // value_hi == 0 + // stack: _ value_lo value_hi + pop 1 + log_2_floor + return - return + // Before: _ *list + // After: _ elem{N - 1}, elem{N - 2}, ..., elem{0} + tasm_list_safe_u32_pop_digest: + push 1 add // _ *list+1 + read_mem 1 // _ len *list + + // Assert that length is not 0 + dup 1 // _ len *list len + push 0 // _ len *list len 0 + eq // _ len *list len==0 + push 0 // _ len *list len==0 0 + eq // _ len *list len!=0 + assert // _ len *list + + // Decrease length value by one and write back to memory + dup 1 // _ len *list len + push -1 // _ len *list len -1 + add // _ len *list len-1 + swap 1 // _ len len-1 *list + write_mem 1 // _ len *list+1 + + // Read elements + swap 1 // _ *list+1 len + push 5 // _ *list+1 len 5 + mul // _ *list+1 5·len + // _ *list+1 offset_for_last_element + add // _ *list+offset_for_last_element+1 + // _ address_for_last_element + read_mem 5 // _ [elements] address_for_last_element-5 + pop 1 // _ [elements] + return - tasm_arithmetic_u64_log_2_floor_else: - // value_hi == 0 - // stack: _ value_lo value_hi - pop 1 - log_2_floor - return + // BEFORE: rhs_hi rhs_lo lhs_hi lhs_lo + // AFTER: (rhs & lhs)_hi (rhs & lhs)_lo + tasm_arithmetic_u64_and: + swap 3 + and + // stack: _ lhs_lo rhs_lo (lhs_hi & rhs_hi) - // Before: _ *list - // After: _ elem{N - 1}, elem{N - 2}, ..., elem{0} - tasm_list_safe_u32_pop_digest: - read_mem - // stack : _ *list, length - - // Assert that length is not 0 - dup 0 - push 0 - eq - push 0 - eq - assert - // stack : _ *list, length - - // Decrease length value by one and write back to memory - swap 1 - dup 1 - push -1 - add - write_mem - swap 1 - // stack : _ *list initial_length - - push 5 - mul - - // stack : _ *list, (offset_for_last_element = (N * initial_length)) - - add - push 1 - add - // stack : _ address_for_last_element - - read_mem - swap 1 - push -1 - add - read_mem - swap 1 - push -1 - add - read_mem - swap 1 - push -1 - add - read_mem - swap 1 - push -1 - add - read_mem - swap 1 - - // Stack: _ [elements], address_for_last_unread_element - - pop 1 - // Stack: _ [elements] + swap 2 + and + // stack: _ (lhs_hi & rhs_hi) (rhs_lo & lhs_lo) - return + return - // BEFORE: rhs_hi rhs_lo lhs_hi lhs_lo - // AFTER: (rhs & lhs)_hi (rhs & lhs)_lo - tasm_arithmetic_u64_and: - swap 3 - and - // stack: _ lhs_lo rhs_lo (lhs_hi & rhs_hi) + // BEFORE: _ value_hi value_lo + // AFTER: _ index_of_last_non-zero_bit + tasm_arithmetic_u64_index_of_last_nonzero_bit: + dup 1 + dup 1 + // _ value_hi value_lo value_hi value_lo - swap 2 - and - // stack: _ (lhs_hi & rhs_hi) (rhs_lo & lhs_lo) + call tasm_arithmetic_u64_decr + // _ value_hi value_lo (value - 1)_hi (value - 1)_lo - return + push 4294967295 + push 4294967295 + // _ value_hi value_lo (value - 1)_hi (value - 1)_lo 0xFFFFFFFF 0xFFFFFFFF - // BEFORE: _ value_hi value_lo - // AFTER: _ index_of_last_non-zero_bit - tasm_arithmetic_u64_index_of_last_nonzero_bit: - dup 1 - dup 1 - // _ value_hi value_lo value_hi value_lo + call tasm_arithmetic_u64_xor + // _ value_hi value_lo ~(value - 1)_hi ~(value - 1)_lo - call tasm_arithmetic_u64_decr - // _ value_hi value_lo (value - 1)_hi (value - 1)_lo + call tasm_arithmetic_u64_and + // _ (value & ~(value - 1))_hi (value & ~(value - 1))_lo - push 4294967295 - push 4294967295 - // _ value_hi value_lo (value - 1)_hi (value - 1)_lo 0xFFFFFFFF 0xFFFFFFFF + // The above value is now a power of two in u64. Calling log2_floor on this + // value gives us the index we are looking for. + call tasm_arithmetic_u64_log_2_floor - call tasm_arithmetic_u64_xor - // _ value_hi value_lo ~(value - 1)_hi ~(value - 1)_lo + return - call tasm_arithmetic_u64_and - // _ (value & ~(value - 1))_hi (value & ~(value - 1))_lo - // The above value is now a power of two in u64. Calling log2_floor on this - // value gives us the index we are looking for. - call tasm_arithmetic_u64_log_2_floor + // Return a pointer to a free address and allocate `size` words for this pointer + // Before: _ size + // After: _ *next_addr + tasm_memory_dyn_malloc: + push 1 // _ size *free_pointer+1 + read_mem 1 // _ size *next_addr' *free_pointer + swap 1 // _ size *free_pointer *next_addr' - return + // add 1 iff `next_addr` was 0, i.e. uninitialized. + dup 0 // _ size *free_pointer *next_addr' *next_addr' + push 0 // _ size *free_pointer *next_addr' *next_addr' 0 + eq // _ size *free_pointer *next_addr' (*next_addr' == 0) + add // _ size *free_pointer *next_addr + dup 0 // _ size *free_pointer *next_addr *next_addr + dup 3 // _ size *free_pointer *next_addr *next_addr size - // Return a pointer to a free address and allocate `size` words for this pointer - // Before: _ size - // After: _ *next_addr - tasm_memory_dyn_malloc: - push 0 // _ size *free_pointer - read_mem // _ size *free_pointer *next_addr' - - // add 1 iff `next_addr` was 0, i.e. uninitialized. - dup 0 // _ size *free_pointer *next_addr' *next_addr' - push 0 // _ size *free_pointer *next_addr' *next_addr' 0 - eq // _ size *free_pointer *next_addr' (*next_addr' == 0) - add // _ size *free_pointer *next_addr - - dup 0 // _ size *free_pointer *next_addr *next_addr - dup 3 // _ size *free_pointer *next_addr *next_addr size - - // Ensure that `size` does not exceed 2^32 - split - swap 1 - push 0 - eq - assert - - add // _ size *free_pointer *next_addr *(next_addr + size) - - // Ensure that no more than 2^32 words are allocated, because I don't want a wrap-around - // in the address space - split - swap 1 - push 0 - eq - assert - - swap 1 // _ size *free_pointer *(next_addr + size) *next_addr - swap 3 // _ *next_addr *free_pointer *(next_addr + size) size - pop 1 // _ *next_addr *free_pointer *(next_addr + size) - write_mem - pop 1 // _ next_addr - return + // Ensure that `size` does not exceed 2^32 + split + swap 1 + push 0 + eq + assert - // BEFORE: rhs_hi rhs_lo lhs_hi lhs_lo - // AFTER: (rhs ^ lhs)_hi (rhs ^ lhs)_lo - tasm_arithmetic_u64_xor: - swap 3 - xor - // stack: _ lhs_lo rhs_lo (lhs_hi ^ rhs_hi) + add // _ size *free_pointer *next_addr *(next_addr + size) - swap 2 - xor - // stack: _ (lhs_hi ^ rhs_hi) (rhs_lo ^ lhs_lo) + // Ensure that no more than 2^32 words are allocated, because I don't want a wrap-around + // in the address space + split + swap 1 + push 0 + eq + assert + + swap 1 // _ size *free_pointer *(next_addr + size) *next_addr + swap 3 // _ *next_addr *free_pointer *(next_addr + size) size + pop 1 // _ *next_addr *free_pointer *(next_addr + size) + swap 1 // _ *next_addr *(next_addr + size) *free_pointer + write_mem 1 // _ *next_addr *free_pointer+1 + pop 1 // _ *next_addr + return - return + // BEFORE: rhs_hi rhs_lo lhs_hi lhs_lo + // AFTER: (rhs ^ lhs)_hi (rhs ^ lhs)_lo + tasm_arithmetic_u64_xor: + swap 3 + xor + // stack: _ lhs_lo rhs_lo (lhs_hi ^ rhs_hi) + + swap 2 + xor + // stack: _ (lhs_hi ^ rhs_hi) (rhs_lo ^ lhs_lo) + + return ) } diff --git a/triton-vm/src/instruction.rs b/triton-vm/src/instruction.rs index 4cd1aa136..5aac6dbb3 100644 --- a/triton-vm/src/instruction.rs +++ b/triton-vm/src/instruction.rs @@ -117,8 +117,8 @@ pub enum AnInstruction { Assert, // Memory access - ReadMem, - WriteMem, + ReadMem(NumberOfWords), + WriteMem(NumberOfWords), // Hashing-related Hash, @@ -171,18 +171,18 @@ impl AnInstruction { Return => 16, Recurse => 24, Assert => 10, - ReadMem => 32, - WriteMem => 18, - Hash => 26, - DivineSibling => 40, - AssertVector => 34, - SpongeInit => 48, - SpongeAbsorb => 42, - SpongeSqueeze => 56, - Add => 50, - Mul => 58, - Invert => 64, - Eq => 66, + ReadMem(_) => 41, + WriteMem(_) => 11, + Hash => 18, + DivineSibling => 32, + AssertVector => 26, + SpongeInit => 40, + SpongeAbsorb => 34, + SpongeSqueeze => 48, + Add => 42, + Mul => 50, + Invert => 56, + Eq => 58, Split => 4, Lt => 6, And => 14, @@ -191,12 +191,12 @@ impl AnInstruction { Pow => 30, DivMod => 20, PopCount => 28, - XxAdd => 74, - XxMul => 82, - XInvert => 72, - XbMul => 90, - ReadIo(_) => 41, - WriteIo(_) => 11, + XxAdd => 66, + XxMul => 74, + XInvert => 64, + XbMul => 82, + ReadIo(_) => 49, + WriteIo(_) => 19, } } @@ -214,8 +214,8 @@ impl AnInstruction { Return => "return", Recurse => "recurse", Assert => "assert", - ReadMem => "read_mem", - WriteMem => "write_mem", + ReadMem(_) => "read_mem", + WriteMem(_) => "write_mem", Hash => "hash", DivineSibling => "divine_sibling", AssertVector => "assert_vector", @@ -253,6 +253,7 @@ impl AnInstruction { Divine(_) => 2, Dup(_) | Swap(_) => 2, Call(_) => 2, + ReadMem(_) | WriteMem(_) => 2, ReadIo(_) | WriteIo(_) => 2, _ => 1, } @@ -284,8 +285,8 @@ impl AnInstruction { Return => Return, Recurse => Recurse, Assert => Assert, - ReadMem => ReadMem, - WriteMem => WriteMem, + ReadMem(x) => ReadMem(*x), + WriteMem(x) => WriteMem(*x), Hash => Hash, DivineSibling => DivineSibling, AssertVector => AssertVector, @@ -339,8 +340,8 @@ impl AnInstruction { Return => 0, Recurse => 0, Assert => -1, - ReadMem => 1, - WriteMem => -1, + ReadMem(n) => n.num_words() as i32, + WriteMem(n) => -(n.num_words() as i32), Hash => -5, DivineSibling => 5, AssertVector => -5, @@ -389,6 +390,7 @@ impl Display for AnInstruction { Pop(arg) | Divine(arg) => write!(f, " {arg}"), Dup(arg) | Swap(arg) => write!(f, " {arg}"), Call(arg) => write!(f, " {arg}"), + ReadMem(arg) | WriteMem(arg) => write!(f, " {arg}"), ReadIo(arg) | WriteIo(arg) => write!(f, " {arg}"), _ => Ok(()), } @@ -402,6 +404,7 @@ impl Instruction { Push(arg) | Call(arg) => Some(*arg), Pop(arg) | Divine(arg) => Some(arg.into()), Dup(arg) | Swap(arg) => Some(arg.into()), + ReadMem(arg) | WriteMem(arg) => Some(arg.into()), ReadIo(arg) | WriteIo(arg) => Some(arg.into()), _ => None, } @@ -417,20 +420,15 @@ impl Instruction { /// if the argument is out of range. #[must_use] pub fn change_arg(&self, new_arg: BFieldElement) -> Option { - let instruction_with_infallible_substitution = match self { - Push(_) => Some(Push(new_arg)), - Call(_) => Some(Call(new_arg)), - _ => None, - }; - if instruction_with_infallible_substitution.is_some() { - return instruction_with_infallible_substitution; - } - let new_instruction = match self { Pop(_) => Some(Pop(new_arg.try_into().ok()?)), + Push(_) => Some(Push(new_arg)), Divine(_) => Some(Divine(new_arg.try_into().ok()?)), Dup(_) => Some(Dup(new_arg.value().try_into().ok()?)), Swap(_) => Some(Swap(new_arg.value().try_into().ok()?)), + Call(_) => Some(Call(new_arg)), + ReadMem(_) => Some(ReadMem(new_arg.try_into().ok()?)), + WriteMem(_) => Some(WriteMem(new_arg.try_into().ok()?)), ReadIo(_) => Some(ReadIo(new_arg.try_into().ok()?)), WriteIo(_) => Some(WriteIo(new_arg.try_into().ok()?)), _ => None, @@ -494,8 +492,8 @@ const fn all_instructions_without_args() -> [AnInstruction; Instr Return, Recurse, Assert, - ReadMem, - WriteMem, + ReadMem(N1), + WriteMem(N1), Hash, DivineSibling, AssertVector, diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 96388ca7d..ef98442f0 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -91,20 +91,20 @@ //! halt //! //! sum_of_squares_secret_in: -//! divine 1 dup 0 mul // s₁² -//! divine 1 dup 0 mul add // s₁²+s₂² -//! divine 1 dup 0 mul add // s₁²+s₂²+s₃² +//! divine 1 dup 0 mul // s₁² +//! divine 1 dup 0 mul add // s₁²+s₂² +//! divine 1 dup 0 mul add // s₁²+s₂²+s₃² //! return //! //! sum_of_squares_ram: -//! push 17 // 17 -//! read_mem // 17 s₄ -//! dup 0 mul // 17 s₄² -//! swap 1 pop 1 // s₄² -//! push 42 // s₄² 42 -//! read_mem // s₄² 42 s₅ -//! dup 0 mul // s₄² 42 s₅² -//! swap 1 pop 1 // s₄² s₅² +//! push 18 // 18 +//! read_mem 1 // s₄ 17 +//! pop 1 // s₄ +//! dup 0 mul // s₄² +//! push 43 // s₄² 43 +//! read_mem 1 // s₄² s₅ 42 +//! pop 1 // s₄² s₅ +//! dup 0 mul // s₄² s₅² //! add // s₄²+s₅² //! return //! ); @@ -341,6 +341,8 @@ macro_rules! triton_asm { [dup $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(dup $arg); $num ] }; [swap $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(swap $arg); $num ] }; [call $arg:ident; $num:expr] => { vec![ $crate::triton_instr!(call $arg); $num ] }; + [read_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_mem $arg); $num ] }; + [write_mem $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_mem $arg); $num ] }; [read_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(read_io $arg); $num ] }; [write_io $arg:literal; $num:expr] => { vec![ $crate::triton_instr!(write_io $arg); $num ] }; [$instr:ident; $num:expr] => { vec![ $crate::triton_instr!($instr); $num ] }; @@ -400,6 +402,16 @@ macro_rules! triton_instr { let instruction = $crate::instruction::AnInstruction::::Call(argument); $crate::instruction::LabelledInstruction::Instruction(instruction) }}; + (read_mem $arg:literal) => {{ + let argument: $crate::op_stack::NumberOfWords = u32::try_into($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::ReadMem(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; + (write_mem $arg:literal) => {{ + let argument: $crate::op_stack::NumberOfWords = u32::try_into($arg).unwrap(); + let instruction = $crate::instruction::AnInstruction::::WriteMem(argument); + $crate::instruction::LabelledInstruction::Instruction(instruction) + }}; (read_io $arg:literal) => {{ let argument: $crate::op_stack::NumberOfWords = u32::try_into($arg).unwrap(); let instruction = $crate::instruction::AnInstruction::::ReadIo(argument); @@ -606,9 +618,9 @@ mod tests { #[test] fn lib_use_initial_ram() { let program = triton_program!( - push 51 read_mem - push 42 read_mem - swap 1 swap 2 mul + push 52 read_mem 1 pop 1 + push 43 read_mem 1 pop 1 + mul write_io 1 halt ); diff --git a/triton-vm/src/op_stack.rs b/triton-vm/src/op_stack.rs index 2f8355706..66636f0fc 100644 --- a/triton-vm/src/op_stack.rs +++ b/triton-vm/src/op_stack.rs @@ -9,6 +9,10 @@ use arbitrary::Arbitrary; use get_size::GetSize; use itertools::Itertools; use num_traits::Zero; +use rand::distributions::Distribution; +use rand::distributions::Standard; +use rand::seq::IteratorRandom; +use rand::Rng; use serde_derive::Deserialize; use serde_derive::Serialize; use strum::EnumCount; @@ -242,6 +246,7 @@ impl UnderflowIO { Deserialize, EnumCount, EnumIter, + Arbitrary, )] pub enum OpStackElement { #[default] @@ -286,6 +291,12 @@ impl OpStackElement { } } +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> OpStackElement { + OpStackElement::iter().choose(rng).unwrap() + } +} + impl Display for OpStackElement { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { let index = self.index(); @@ -413,6 +424,7 @@ impl From<&OpStackElement> for BFieldElement { Deserialize, EnumCount, EnumIter, + Arbitrary, )] pub enum NumberOfWords { #[default] @@ -448,10 +460,15 @@ impl NumberOfWords { } } +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> NumberOfWords { + NumberOfWords::iter().choose(rng).unwrap() + } +} + impl Display for NumberOfWords { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - let index = self.num_words(); - write!(f, "{index}") + write!(f, "{}", self.num_words()) } } @@ -780,6 +797,13 @@ mod tests { assert_eq!(computed_range, expected_range); } + #[test] + fn number_of_legal_number_of_words_corresponds_to_distinct_number_of_number_of_words() { + let legal_values = NumberOfWords::legal_values(); + let distinct_values = NumberOfWords::COUNT; + assert_eq!(distinct_values, legal_values.len()); + } + #[test] fn compute_illegal_values_of_number_of_words() { let _ = NumberOfWords::illegal_values(); diff --git a/triton-vm/src/parser.rs b/triton-vm/src/parser.rs index cb8f9bcac..82a7e4920 100644 --- a/triton-vm/src/parser.rs +++ b/triton-vm/src/parser.rs @@ -233,8 +233,8 @@ fn an_instruction(s: &str) -> ParseResult> { let control_flow = alt((nop, skiz, call, return_, recurse, halt)); // Memory access - let read_mem = instruction("read_mem", ReadMem); - let write_mem = instruction("write_mem", WriteMem); + let read_mem = read_mem_instruction(); + let write_mem = write_mem_instruction(); let memory_access = alt((read_mem, write_mem)); @@ -400,6 +400,26 @@ fn call_instruction<'a>() -> impl Fn(&'a str) -> ParseResult impl Fn(&str) -> ParseResult> { + move |s: &str| { + let (s, _) = token1("read_mem")(s)?; // require space after instruction name + let (s, arg) = number_of_words(s)?; + let (s, _) = comment_or_whitespace1(s)?; + + Ok((s, ReadMem(arg))) + } +} + +fn write_mem_instruction() -> impl Fn(&str) -> ParseResult> { + move |s: &str| { + let (s, _) = token1("write_mem")(s)?; // require space after instruction name + let (s, arg) = number_of_words(s)?; + let (s, _) = comment_or_whitespace1(s)?; + + Ok((s, WriteMem(arg))) + } +} + fn read_io_instruction() -> impl Fn(&str) -> ParseResult> { move |s: &str| { let (s, _) = token1("read_io")(s)?; // require space after instruction name @@ -569,6 +589,7 @@ pub(crate) mod tests { use rand::distributions::WeightedIndex; use rand::prelude::*; use rand::Rng; + use strum::EnumCount; use twenty_first::shared_math::digest::DIGEST_LENGTH; use LabelledInstruction::*; @@ -666,7 +687,17 @@ pub(crate) mod tests { let mut rng = thread_rng(); let difficult_instructions = vec![ - "pop", "push", "divine", "dup", "swap", "skiz", "call", "read_io", "write_io", + "pop", + "push", + "divine", + "dup", + "swap", + "skiz", + "call", + "read_mem", + "write_mem", + "read_io", + "write_io", ]; let simple_instructions = ALL_INSTRUCTION_NAMES .into_iter() @@ -675,7 +706,7 @@ pub(crate) mod tests { let generators = [vec!["simple"], difficult_instructions].concat(); // Test difficult instructions more frequently. - let weights = vec![simple_instructions.len(), 2, 2, 2, 6, 6, 2, 10, 2, 2]; + let weights = vec![simple_instructions.len(), 2, 2, 2, 6, 6, 2, 10, 2, 2, 2, 2]; assert_eq!( generators.len(), @@ -692,28 +723,28 @@ pub(crate) mod tests { } "pop" => { - let arg: usize = rng.gen_range(1..=5); + let arg: NumberOfWords = rng.gen(); vec!["pop".to_string(), format!("{arg}")] } "push" => { - let max: i128 = BFieldElement::MAX as i128; - let arg: i128 = rng.gen_range(-max..max); + let max = BFieldElement::MAX as i128; + let arg = rng.gen_range(-max..max); vec!["push".to_string(), format!("{arg}")] } "divine" => { - let arg: usize = rng.gen_range(1..=5); + let arg: NumberOfWords = rng.gen(); vec!["divine".to_string(), format!("{arg}")] } "dup" => { - let arg: usize = rng.gen_range(0..16); + let arg: OpStackElement = rng.gen(); vec!["dup".to_string(), format!("{arg}")] } "swap" => { - let arg: usize = rng.gen_range(1..16); + let arg: usize = rng.gen_range(1..OpStackElement::COUNT); vec!["swap".to_string(), format!("{arg}")] } @@ -728,13 +759,23 @@ pub(crate) mod tests { vec!["call".to_string(), some_label] } + "read_mem" => { + let arg: NumberOfWords = rng.gen(); + vec!["read_mem".to_string(), format!("{arg}")] + } + + "write_mem" => { + let arg: NumberOfWords = rng.gen(); + vec!["write_mem".to_string(), format!("{arg}")] + } + "read_io" => { - let arg: usize = rng.gen_range(1..=5); + let arg: NumberOfWords = rng.gen(); vec!["read_io".to_string(), format!("{arg}")] } "write_io" => { - let arg: usize = rng.gen_range(1..=5); + let arg: NumberOfWords = rng.gen(); vec!["write_io".to_string(), format!("{arg}")] } diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 4bf7da8b2..7dac3c6de 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -1119,7 +1119,6 @@ pub(crate) mod tests { use twenty_first::shared_math::other::random_elements; use crate::example_programs::*; - use crate::instruction::AnInstruction; use crate::instruction::Instruction; use crate::op_stack::OpStackElement; use crate::shared_tests::*; @@ -1220,27 +1219,27 @@ pub(crate) mod tests { #[test] fn print_ram_table_example_for_specification() { let program = triton_program!( - push 5 push 6 write_mem pop 1 - push 15 push 16 write_mem pop 1 - push 5 read_mem pop 2 - push 15 read_mem pop 2 - push 5 push 7 write_mem pop 1 - push 15 read_mem - push 5 read_mem + push 9 push 8 push 5 write_mem 2 pop 1 // write 8 to address 5, 9 to address 6 + push 18 push 15 write_mem 1 pop 1 // write 18 to address 15 + push 6 read_mem 1 pop 2 // read from address 5 + push 16 read_mem 1 pop 2 // read from address 15 + push 7 push 5 write_mem 1 pop 1 // write 7 to address 5 + push 16 read_mem 1 // _ 18 15 + push 6 read_mem 1 // _ 18 15 7 5 halt ); - let (_, _, master_base_table) = - master_base_table_for_low_security_level(ProgramAndInput::without_input(program)); + let (_, _, master_base_table, _, _) = + master_tables_for_low_security_level(ProgramAndInput::without_input(program)); println!(); println!("Processor Table:"); println!( - "| clk | pi | ci | nia | st0 \ - | st1 | st2 | st3 | ramp | ramv |" + "| clk | ci | nia \ + | st0 | st1 | st2 | st3 |" ); println!( - "|-----------:|:-----------|:-----------|:-----------|-----------:\ - |-----------:|-----------:|-----------:|-----------:|-----------:|" + "|-----------:|:-----------|:-----------\ + |-----------:|-----------:|-----------:|-----------:|" ); for row in master_base_table.table(ProcessorTable).rows() { let clk = row[ProcessorBaseTableColumn::CLK.base_table_index()].to_string(); @@ -1248,18 +1247,10 @@ pub(crate) mod tests { let st1 = row[ProcessorBaseTableColumn::ST1.base_table_index()].to_string(); let st2 = row[ProcessorBaseTableColumn::ST2.base_table_index()].to_string(); let st3 = row[ProcessorBaseTableColumn::ST3.base_table_index()].to_string(); - let ramp = row[ProcessorBaseTableColumn::RAMP.base_table_index()].to_string(); - let ramv = row[ProcessorBaseTableColumn::RAMV.base_table_index()].to_string(); - - let prev_instruction = - row[ProcessorBaseTableColumn::PreviousInstruction.base_table_index()].value(); - let pi = match Instruction::try_from(prev_instruction) { - Ok(AnInstruction::Halt) | Err(_) => "-".to_string(), - Ok(instr) => instr.name().to_string(), - }; + let (ci, nia) = ci_and_nia_from_master_table_row(row); - let interesting_cols = [clk, pi, ci, nia, st0, st1, st2, st3, ramp, ramv]; + let interesting_cols = [clk, ci, nia, st0, st1, st2, st3]; let interesting_cols = interesting_cols .iter() .map(|ff| format!("{:>10}", format!("{ff}"))) @@ -1269,23 +1260,25 @@ pub(crate) mod tests { } println!(); println!("RAM Table:"); - println!("| clk | pi | ramp | ramv | iord |"); + println!("| clk | type | pointer | value | iord |"); println!("|-----------:|:-----------|-----------:|-----------:|-----:|"); for row in master_base_table.table(TableId::RamTable).rows() { let clk = row[RamBaseTableColumn::CLK.base_table_index()].to_string(); - let ramp = row[RamBaseTableColumn::RAMP.base_table_index()].to_string(); - let ramv = row[RamBaseTableColumn::RAMV.base_table_index()].to_string(); + let ramp = row[RamBaseTableColumn::RamPointer.base_table_index()].to_string(); + let ramv = row[RamBaseTableColumn::RamValue.base_table_index()].to_string(); let iord = row[RamBaseTableColumn::InverseOfRampDifference.base_table_index()].to_string(); - let prev_instruction = - row[RamBaseTableColumn::PreviousInstruction.base_table_index()].value(); - let pi = match Instruction::try_from(prev_instruction) { - Ok(AnInstruction::Halt) | Err(_) => "-".to_string(), - Ok(instr) => instr.name().to_string(), - }; + let instruction_type = + match row[RamBaseTableColumn::InstructionType.base_table_index()] { + ram_table::INSTRUCTION_TYPE_READ => "read", + ram_table::INSTRUCTION_TYPE_WRITE => "write", + ram_table::PADDING_INDICATOR => "pad", + _ => "-", + } + .to_string(); - let interesting_cols = [clk, pi, ramp, ramv, iord]; + let interesting_cols = [clk, instruction_type, ramp, ramv, iord]; let interesting_cols = interesting_cols .iter() .map(|ff| format!("{:>10}", format!("{ff}"))) diff --git a/triton-vm/src/table/challenges.rs b/triton-vm/src/table/challenges.rs index a1ea46951..624a5b20f 100644 --- a/triton-vm/src/table/challenges.rs +++ b/triton-vm/src/table/challenges.rs @@ -103,9 +103,9 @@ pub enum ChallengeId { OpStackFirstUnderflowElementWeight, RamClkWeight, - RamRampWeight, - RamRamvWeight, - RamPreviousInstructionWeight, + RamPointerWeight, + RamValueWeight, + RamInstructionTypeWeight, JumpStackClkWeight, JumpStackCiWeight, diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index e5994401b..aff602d5c 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -330,6 +330,7 @@ pub struct MasterBaseTable { program_table_len: usize, main_execution_len: usize, op_stack_table_len: usize, + ram_table_len: usize, hash_coprocessor_execution_len: usize, cascade_table_len: usize, u32_coprocesor_execution_len: usize, @@ -591,6 +592,7 @@ impl MasterBaseTable { program_table_len: aet.program_table_length(), main_execution_len: aet.processor_table_length(), op_stack_table_len: aet.op_stack_table_length(), + ram_table_len: aet.ram_table_length(), hash_coprocessor_execution_len: aet.hash_table_length(), cascade_table_len: aet.cascade_table_length(), u32_coprocesor_execution_len: aet.u32_table_length(), @@ -725,14 +727,13 @@ impl MasterBaseTable { fn all_table_lengths(&self) -> [usize; NUM_TABLES_WITHOUT_DEGREE_LOWERING] { let processor_table_len = self.main_execution_len; - let ram_table_len = self.main_execution_len; let jump_stack_table_len = self.main_execution_len; [ self.program_table_len, processor_table_len, self.op_stack_table_len, - ram_table_len, + self.ram_table_len, jump_stack_table_len, self.hash_coprocessor_execution_len, self.cascade_table_len, diff --git a/triton-vm/src/table/op_stack_table.rs b/triton-vm/src/table/op_stack_table.rs index feab5663e..349c5b312 100644 --- a/triton-vm/src/table/op_stack_table.rs +++ b/triton-vm/src/table/op_stack_table.rs @@ -248,7 +248,7 @@ impl ExtOpStackTable { let log_derivative_remains_or_stack_pointer_doesnt_change = log_derivative_remains.clone() * (stack_pointer_next.clone() - stack_pointer.clone()); let log_derivatve_remains_or_next_row_is_not_padding_row = - log_derivative_remains.clone() * next_row_is_not_padding_row; + log_derivative_remains * next_row_is_not_padding_row; let log_derivative_updates_correctly = log_derivative_accumulates_or_stack_pointer_changes_or_next_row_is_padding_row @@ -324,18 +324,18 @@ impl OpStackTable { pub fn pad_trace(mut op_stack_table: ArrayViewMut2, op_stack_table_len: usize) { let last_row_index = op_stack_table_len.saturating_sub(1); - let mut last_row = op_stack_table.row(last_row_index).to_owned(); - last_row[IB1ShrinkStack.base_table_index()] = PADDING_VALUE; + let mut padding_row = op_stack_table.row(last_row_index).to_owned(); + padding_row[IB1ShrinkStack.base_table_index()] = PADDING_VALUE; if op_stack_table_len == 0 { let first_stack_pointer = u32::try_from(OpStackElement::COUNT).unwrap().into(); - last_row[StackPointer.base_table_index()] = first_stack_pointer; + padding_row[StackPointer.base_table_index()] = first_stack_pointer; } let mut padding_section = op_stack_table.slice_mut(s![op_stack_table_len.., ..]); padding_section .axis_iter_mut(Axis(0)) .into_par_iter() - .for_each(|mut row| row.assign(&last_row)); + .for_each(|mut row| row.assign(&padding_row)); } pub fn extend( diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index ec56e6b4c..2b119787f 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -29,6 +29,7 @@ use crate::table::constraint_circuit::DualRowIndicator::*; use crate::table::constraint_circuit::SingleRowIndicator::*; use crate::table::constraint_circuit::*; use crate::table::cross_table_argument::*; +use crate::table::ram_table; use crate::table::table_column::ProcessorBaseTableColumn::*; use crate::table::table_column::ProcessorExtTableColumn::*; use crate::table::table_column::*; @@ -90,15 +91,14 @@ impl ProcessorTable { processor_table.slice_mut(s![processor_table_len.., CLK.base_table_index()]), ); - // The memory-like tables “RAM” and “Jump Stack” do not have a padding indicator. Hence, - // clock jump differences are being looked up in their padding sections. The clock jump - // differences in that section are always 1. The lookup multiplicities of clock value 1 must - // be increased accordingly: one per padding row, times the number of memory-like tables - // without padding indicator, which is 2. - let num_padding_rows = 2 * (processor_table.nrows() - processor_table_len); - let num_pad_rows = BFieldElement::new(num_padding_rows as u64); + // The Jump Stack Table does not have a padding indicator. Hence, clock jump differences are + // being looked up in its padding sections. The clock jump differences in that section are + // always 1. The lookup multiplicities of clock value 1 must be increased accordingly: one + // per padding row. + let num_padding_rows = processor_table.nrows() - processor_table_len; + let num_padding_rows = BFieldElement::new(num_padding_rows as u64); let mut row_1 = processor_table.row_mut(1); - row_1[ClockJumpDifferenceLookupMultiplicity.base_table_index()] += num_pad_rows; + row_1[ClockJumpDifferenceLookupMultiplicity.base_table_index()] += num_padding_rows; } pub fn extend( @@ -167,19 +167,14 @@ impl ProcessorTable { challenges, ); - // RAM Table - let clk = current_row[CLK.base_table_index()]; - let ramv = current_row[RAMV.base_table_index()]; - let ramp = current_row[RAMP.base_table_index()]; - let previous_instruction = current_row[PreviousInstruction.base_table_index()]; - let compressed_row_for_ram_table_permutation_argument = clk * challenges[RamClkWeight] - + ramp * challenges[RamRampWeight] - + ramv * challenges[RamRamvWeight] - + previous_instruction * challenges[RamPreviousInstructionWeight]; - ram_table_running_product *= - challenges[RamIndeterminate] - compressed_row_for_ram_table_permutation_argument; + if let Some(factor) = + Self::factor_for_ram_table_running_product(previous_row, current_row, challenges) + { + ram_table_running_product *= factor; + }; // JumpStack Table + let clk = current_row[CLK.base_table_index()]; let ci = current_row[CI.base_table_index()]; let jsp = current_row[JSP.base_table_index()]; let jso = current_row[JSO.base_table_index()]; @@ -398,6 +393,56 @@ impl ProcessorTable { factor } + fn factor_for_ram_table_running_product( + maybe_previous_row: Option>, + current_row: ArrayView1, + challenges: &Challenges, + ) -> Option { + let is_padding_row = current_row[IsPadding.base_table_index()].is_one(); + if is_padding_row { + return None; + } + + let previous_row = maybe_previous_row?; + let previous_instruction = Self::instruction_from_row(previous_row)?; + + let instruction_type = match previous_instruction { + ReadMem(_) => ram_table::INSTRUCTION_TYPE_READ, + WriteMem(_) => ram_table::INSTRUCTION_TYPE_WRITE, + _ => return None, + }; + + // longer stack means relevant information is on top of stack, i.e., in stack registers + let row_with_longer_stack = match previous_instruction { + ReadMem(_) => current_row.view(), + WriteMem(_) => previous_row.view(), + _ => unreachable!(), + }; + let op_stack_delta = previous_instruction + .op_stack_size_influence() + .unsigned_abs() as usize; + + let mut factor = XFieldElement::one(); + for ram_pointer_offset in 0..op_stack_delta { + let num_ram_pointers = 1; + let ram_value_index = ram_pointer_offset + num_ram_pointers; + let ram_value_column = Self::op_stack_column_by_index(ram_value_index); + let ram_value = row_with_longer_stack[ram_value_column.base_table_index()]; + + let ram_pointer = row_with_longer_stack[ST0.base_table_index()]; + let offset = BFieldElement::new(ram_pointer_offset as u64); + let offset_ram_pointer = ram_pointer + offset; + + let clk = previous_row[CLK.base_table_index()]; + let compressed_row = clk * challenges[RamClkWeight] + + instruction_type * challenges[RamInstructionTypeWeight] + + offset_ram_pointer * challenges[RamPointerWeight] + + ram_value * challenges[RamValueWeight]; + factor *= challenges[RamIndeterminate] - compressed_row; + } + Some(factor) + } + fn instruction_from_row(row: ArrayView1) -> Option { let opcode = row[CI.base_table_index()]; let instruction: Instruction = opcode.try_into().ok()?; @@ -467,8 +512,6 @@ impl ExtProcessorTable { let st9_is_0 = base_row(ST9); let st10_is_0 = base_row(ST10); let op_stack_pointer_is_16 = base_row(OpStackPointer) - constant(16); - let ramp_is_0 = base_row(RAMP); - let previous_instruction_is_0 = base_row(PreviousInstruction); // Compress the program digest using an Evaluation Argument. // Lowest index in the digest corresponds to lowest index on the stack. @@ -514,13 +557,8 @@ impl ExtProcessorTable { ext_row(OpStackTablePermArg) - x_constant(PermArg::default_initial()); // ram table - let ram_indeterminate = challenge(RamIndeterminate); - let ram_ramv_weight = challenge(RamRamvWeight); - // note: `clk`, and `ramp` are already constrained to be 0. - let compressed_row_for_ram_table = ram_ramv_weight * base_row(RAMV); - let running_product_for_ram_table_is_initialized_correctly = ext_row(RamTablePermArg) - - x_constant(PermArg::default_initial()) - * (ram_indeterminate - compressed_row_for_ram_table); + let running_product_for_ram_table_is_initialized_correctly = + ext_row(RamTablePermArg) - x_constant(PermArg::default_initial()); // jump-stack table let jump_stack_indeterminate = challenge(JumpStackIndeterminate); @@ -585,8 +623,6 @@ impl ExtProcessorTable { st10_is_0, compressed_program_digest_is_expected_program_digest, op_stack_pointer_is_16, - ramp_is_0, - previous_instruction_is_0, running_evaluation_for_standard_input_is_initialized_correctly, instruction_lookup_log_derivative_is_initialized_correctly, running_evaluation_for_standard_output_is_initialized_correctly, @@ -831,17 +867,14 @@ impl ExtProcessorTable { fn instruction_group_keep_ram( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) }; - vec![ - next_base_row(RAMV) - curr_base_row(RAMV), - next_base_row(RAMP) - curr_base_row(RAMP), - ] + vec![next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg)] } fn instruction_group_no_io( @@ -1616,24 +1649,11 @@ impl ExtProcessorTable { fn instruction_read_mem( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // the RAM pointer is overwritten with st0 - let update_ramp = next_base_row(RAMP) - curr_base_row(ST0); - - // The top of the stack is overwritten with the RAM value. - let st0_becomes_ramv = next_base_row(ST0) - next_base_row(RAMV); - - let specific_constraints = vec![update_ramp, st0_becomes_ramv]; [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_grow_op_stack(circuit_builder), + Self::instruction_group_step_2(circuit_builder), + Self::instruction_group_decompose_arg(circuit_builder), + Self::read_from_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), + Self::prohibit_any_illegal_number_of_words(circuit_builder), Self::instruction_group_no_io(circuit_builder), ] .concat() @@ -1642,24 +1662,11 @@ impl ExtProcessorTable { fn instruction_write_mem( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let curr_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) - }; - let next_base_row = |col: ProcessorBaseTableColumn| { - circuit_builder.input(NextBaseRow(col.master_base_table_index())) - }; - - // the RAM pointer is overwritten with st1 - let update_ramp = next_base_row(RAMP) - curr_base_row(ST1); - - // The RAM value is overwritten with the top of the stack. - let ramv_becomes_st0 = next_base_row(RAMV) - curr_base_row(ST0); - - let specific_constraints = vec![update_ramp, ramv_becomes_st0]; [ - specific_constraints, - Self::instruction_group_step_1(circuit_builder), - Self::instruction_group_shrink_op_stack(circuit_builder), + Self::instruction_group_step_2(circuit_builder), + Self::instruction_group_decompose_arg(circuit_builder), + Self::write_to_ram_any_of(circuit_builder, &NumberOfWords::legal_values()), + Self::prohibit_any_illegal_number_of_words(circuit_builder), Self::instruction_group_no_io(circuit_builder), ] .concat() @@ -2294,8 +2301,8 @@ impl ExtProcessorTable { Return => ExtProcessorTable::instruction_return(circuit_builder), Recurse => ExtProcessorTable::instruction_recurse(circuit_builder), Assert => ExtProcessorTable::instruction_assert(circuit_builder), - ReadMem => ExtProcessorTable::instruction_read_mem(circuit_builder), - WriteMem => ExtProcessorTable::instruction_write_mem(circuit_builder), + ReadMem(_) => ExtProcessorTable::instruction_read_mem(circuit_builder), + WriteMem(_) => ExtProcessorTable::instruction_write_mem(circuit_builder), Hash => ExtProcessorTable::instruction_hash(circuit_builder), DivineSibling => ExtProcessorTable::instruction_divine_sibling(circuit_builder), AssertVector => ExtProcessorTable::instruction_assert_vector(circuit_builder), @@ -2765,27 +2772,208 @@ impl ExtProcessorTable { challenge(OpStackIndeterminate) - compressed_row } - fn running_product_for_ram_table_updates_correctly( + fn write_to_ram_any_of( circuit_builder: &ConstraintCircuitBuilder, - ) -> ConstraintCircuitMonad { - let challenge = |c: ChallengeId| circuit_builder.challenge(c); + number_of_words: &[usize], + ) -> Vec> { + let all_constraint_groups = number_of_words + .iter() + .map(|&n| { + Self::conditional_constraints_for_writing_n_elements_to_ram(circuit_builder, n) + }) + .collect_vec(); + Self::combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) + } + + fn read_from_ram_any_of( + circuit_builder: &ConstraintCircuitBuilder, + number_of_words: &[usize], + ) -> Vec> { + let all_constraint_groups = number_of_words + .iter() + .map(|&n| { + Self::conditional_constraints_for_reading_n_elements_from_ram(circuit_builder, n) + }) + .collect_vec(); + Self::combine_mutually_exclusive_constraint_groups(circuit_builder, all_constraint_groups) + } + + fn conditional_constraints_for_writing_n_elements_to_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> Vec> { + Self::shrink_stack_by_n_and_write_n_elements_to_ram(circuit_builder, n) + .into_iter() + .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) + .collect() + } + + fn conditional_constraints_for_reading_n_elements_from_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> Vec> { + Self::grow_stack_by_n_and_read_n_elements_from_ram(circuit_builder, n) + .into_iter() + .map(|constraint| Self::indicator_polynomial(circuit_builder, n) * constraint) + .collect() + } + + fn shrink_stack_by_n_and_write_n_elements_to_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap().into()); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; let next_base_row = |col: ProcessorBaseTableColumn| { circuit_builder.input(NextBaseRow(col.master_base_table_index())) }; + + let op_stack_pointer_shrinks_by_n = + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) + constant(n); + let ram_pointer_grows_by_n = next_base_row(ST0) - curr_base_row(ST0) - constant(n); + + let mut constraints = vec![ + op_stack_pointer_shrinks_by_n, + ram_pointer_grows_by_n, + Self::running_product_op_stack_accounts_for_shrinking_stack_by(circuit_builder, n), + Self::running_product_ram_accounts_for_writing_n_elements(circuit_builder, n), + ]; + + let num_ram_pointers = 1; + for i in n + num_ram_pointers..OpStackElement::COUNT { + let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); + let next_stack_element = ProcessorTable::op_stack_column_by_index(i - n); + let element_i_is_shifted_by_n = + next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + constraints.push(element_i_is_shifted_by_n); + } + constraints + } + + fn grow_stack_by_n_and_read_n_elements_from_ram( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> Vec> { + let constant = |c: usize| circuit_builder.b_constant(u32::try_from(c).unwrap().into()); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let next_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(NextBaseRow(col.master_base_table_index())) + }; + + let op_stack_pointer_grows_by_n = + next_base_row(OpStackPointer) - curr_base_row(OpStackPointer) - constant(n); + let ram_pointer_shrinks_by_n = next_base_row(ST0) - curr_base_row(ST0) + constant(n); + + let mut constraints = vec![ + op_stack_pointer_grows_by_n, + ram_pointer_shrinks_by_n, + Self::running_product_op_stack_accounts_for_growing_stack_by(circuit_builder, n), + Self::running_product_ram_accounts_for_reading_n_elements(circuit_builder, n), + ]; + + let num_ram_pointers = 1; + for i in num_ram_pointers..OpStackElement::COUNT - n { + let curr_stack_element = ProcessorTable::op_stack_column_by_index(i); + let next_stack_element = ProcessorTable::op_stack_column_by_index(i + n); + let element_i_is_shifted_by_n = + next_base_row(next_stack_element) - curr_base_row(curr_stack_element); + constraints.push(element_i_is_shifted_by_n); + } + constraints + } + + fn running_product_ram_accounts_for_writing_n_elements( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c.into()); let curr_ext_row = |col: ProcessorExtTableColumn| { circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) }; let next_ext_row = |col: ProcessorExtTableColumn| { circuit_builder.input(NextExtRow(col.master_ext_table_index())) }; + let single_write_factor = |ram_pointer_offset| { + Self::single_factor_for_permutation_argument_with_ram_table( + circuit_builder, + CurrentBaseRow, + ram_table::INSTRUCTION_TYPE_WRITE, + ram_pointer_offset, + ) + }; + + let mut factor = constant(1); + for ram_pointer_offset in 0..n { + factor = factor * single_write_factor(ram_pointer_offset); + } + + next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor + } + + fn running_product_ram_accounts_for_reading_n_elements( + circuit_builder: &ConstraintCircuitBuilder, + n: usize, + ) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c.into()); + let curr_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(CurrentExtRow(col.master_ext_table_index())) + }; + let next_ext_row = |col: ProcessorExtTableColumn| { + circuit_builder.input(NextExtRow(col.master_ext_table_index())) + }; + let single_read_factor = |ram_pointer_offset| { + Self::single_factor_for_permutation_argument_with_ram_table( + circuit_builder, + NextBaseRow, + ram_table::INSTRUCTION_TYPE_READ, + ram_pointer_offset, + ) + }; - let compressed_row = challenge(RamClkWeight) * next_base_row(CLK) - + challenge(RamRampWeight) * next_base_row(RAMP) - + challenge(RamRamvWeight) * next_base_row(RAMV) - + challenge(RamPreviousInstructionWeight) * next_base_row(PreviousInstruction); + let mut factor = constant(1); + for ram_pointer_offset in 0..n { + factor = factor * single_read_factor(ram_pointer_offset); + } - next_ext_row(RamTablePermArg) - - curr_ext_row(RamTablePermArg) * (challenge(RamIndeterminate) - compressed_row) + next_ext_row(RamTablePermArg) - curr_ext_row(RamTablePermArg) * factor + } + + fn single_factor_for_permutation_argument_with_ram_table( + circuit_builder: &ConstraintCircuitBuilder, + row_with_longer_stack_indicator: fn(usize) -> DualRowIndicator, + instruction_type: BFieldElement, + ram_pointer_offset: usize, + ) -> ConstraintCircuitMonad { + let constant = |c: u32| circuit_builder.b_constant(c.into()); + let b_constant = |c| circuit_builder.b_constant(c); + let challenge = |c: ChallengeId| circuit_builder.challenge(c); + let curr_base_row = |col: ProcessorBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(col.master_base_table_index())) + }; + let row_with_longer_stack = |col: ProcessorBaseTableColumn| { + circuit_builder.input(row_with_longer_stack_indicator( + col.master_base_table_index(), + )) + }; + + let num_ram_pointers = 1; + let ram_value_index = ram_pointer_offset + num_ram_pointers; + let ram_value_column = ProcessorTable::op_stack_column_by_index(ram_value_index); + let ram_value = row_with_longer_stack(ram_value_column); + + let ram_pointer = row_with_longer_stack(ST0); + let offset = constant(ram_pointer_offset as u32); + let offset_ram_pointer = ram_pointer + offset; + + let compressed_row = curr_base_row(CLK) * challenge(RamClkWeight) + + b_constant(instruction_type) * challenge(RamInstructionTypeWeight) + + offset_ram_pointer * challenge(RamPointerWeight) + + ram_value * challenge(RamValueWeight); + challenge(RamIndeterminate) - compressed_row } fn running_product_for_jump_stack_table_updates_correctly( @@ -3100,15 +3288,9 @@ impl ExtProcessorTable { let clk_increases_by_1 = next_base_row(CLK) - curr_base_row(CLK) - constant(1); let is_padding_is_0_or_does_not_change = curr_base_row(IsPadding) * (next_base_row(IsPadding) - curr_base_row(IsPadding)); - let previous_instruction_is_copied_correctly = (next_base_row(PreviousInstruction) - - curr_base_row(CI)) - * (constant(1) - next_base_row(IsPadding)); - - let instruction_independent_constraints = vec![ - clk_increases_by_1, - is_padding_is_0_or_does_not_change, - previous_instruction_is_copied_correctly, - ]; + + let instruction_independent_constraints = + vec![clk_increases_by_1, is_padding_is_0_or_does_not_change]; // instruction-specific constraints let all_transition_constraints_by_instruction = ALL_INSTRUCTIONS.map(|instruction| { @@ -3137,7 +3319,6 @@ impl ExtProcessorTable { let table_linking_constraints = vec![ Self::log_derivative_accumulates_clk_next(circuit_builder), Self::log_derivative_for_instruction_lookup_updates_correctly(circuit_builder), - Self::running_product_for_ram_table_updates_correctly(circuit_builder), Self::running_product_for_jump_stack_table_updates_correctly(circuit_builder), Self::running_evaluation_hash_input_updates_correctly(circuit_builder), Self::running_evaluation_hash_digest_updates_correctly(circuit_builder), @@ -3227,14 +3408,6 @@ impl<'a> Display for ProcessorTraceRow<'a> { self.row[JSO.base_table_index()].value(), self.row[JSD.base_table_index()].value(), )?; - row( - f, - format!( - "ramp: {:>width$} │ ramv: {:>width$} ╵", - self.row[RAMP.base_table_index()].value(), - self.row[RAMV.base_table_index()].value(), - ), - )?; row( f, format!( @@ -3659,24 +3832,44 @@ pub(crate) mod tests { #[test] fn transition_constraints_for_instruction_read_mem() { - let programs = [triton_program!(push 5 push 3 write_mem read_mem halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 3)); + let programs = [ + triton_program!(push 1 read_mem 1 push 0 eq assert assert halt), + triton_program!(push 2 read_mem 2 push 0 eq assert swap 1 push 2 eq assert halt), + triton_program!(push 3 read_mem 3 push 0 eq assert swap 2 push 3 eq assert halt), + triton_program!(push 4 read_mem 4 push 0 eq assert swap 3 push 4 eq assert halt), + triton_program!(push 5 read_mem 5 push 0 eq assert swap 4 push 5 eq assert halt), + ]; + let initial_ram = (0..5).map(|i| (i, i + 1)).collect(); + let non_determinism = NonDeterminism::default().with_ram(initial_ram); + let programs_with_input = programs.map(|program| ProgramAndInput { + program, + public_input: vec![], + non_determinism: non_determinism.clone(), + }); + let test_rows = programs_with_input.map(|p_w_i| test_row_from_program_with_input(p_w_i, 1)); let debug_info = TestRowsDebugInfo { - instruction: ReadMem, - debug_cols_curr_row: vec![ST0, ST1, RAMP, RAMV], - debug_cols_next_row: vec![ST0, ST1, RAMP, RAMV], + instruction: ReadMem(N1), + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0, ST1], }; assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); } #[test] fn transition_constraints_for_instruction_write_mem() { - let programs = [triton_program!(push 5 push 3 write_mem read_mem halt)]; - let test_rows = programs.map(|program| test_row_from_program(program, 2)); + let push_10_elements = triton_asm![push 2; 10]; + let programs = [ + triton_program!({&push_10_elements} write_mem 1 halt), + triton_program!({&push_10_elements} write_mem 2 halt), + triton_program!({&push_10_elements} write_mem 3 halt), + triton_program!({&push_10_elements} write_mem 4 halt), + triton_program!({&push_10_elements} write_mem 5 halt), + ]; + let test_rows = programs.map(|program| test_row_from_program(program, 10)); let debug_info = TestRowsDebugInfo { - instruction: WriteMem, - debug_cols_curr_row: vec![ST0, ST1, RAMP, RAMV], - debug_cols_next_row: vec![ST0, ST1, RAMP, RAMV], + instruction: WriteMem(N1), + debug_cols_curr_row: vec![ST0, ST1], + debug_cols_next_row: vec![ST0, ST1], }; assert_constraints_for_rows_with_debug_info(&test_rows, debug_info); } @@ -4342,4 +4535,24 @@ pub(crate) mod tests { &challenges, ); } + + #[proptest] + fn constructing_factor_for_ram_table_running_product_never_panics( + has_previous_row: bool, + #[strategy(vec(arb(), BASE_WIDTH))] previous_row: Vec, + #[strategy(vec(arb(), BASE_WIDTH))] current_row: Vec, + #[strategy(arb())] challenges: Challenges, + ) { + let previous_row = Array1::from(previous_row); + let current_row = Array1::from(current_row); + let maybe_previous_row = match has_previous_row { + true => Some(previous_row.view()), + false => None, + }; + let _ = ProcessorTable::factor_for_ram_table_running_product( + maybe_previous_row, + current_row.view(), + &challenges, + ); + } } diff --git a/triton-vm/src/table/ram_table.rs b/triton-vm/src/table/ram_table.rs index 413d37764..da755f635 100644 --- a/triton-vm/src/table/ram_table.rs +++ b/triton-vm/src/table/ram_table.rs @@ -1,8 +1,11 @@ -use std::collections::HashMap; +use std::cmp::Ordering; +use arbitrary::Arbitrary; +use itertools::Itertools; use ndarray::parallel::prelude::*; use ndarray::s; use ndarray::Array1; +use ndarray::Array2; use ndarray::ArrayView1; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; @@ -11,12 +14,13 @@ use num_traits::One; use num_traits::Zero; use strum::EnumCount; use twenty_first::shared_math::b_field_element::BFieldElement; +use twenty_first::shared_math::b_field_element::BFIELD_ONE; +use twenty_first::shared_math::b_field_element::BFIELD_ZERO; use twenty_first::shared_math::polynomial::Polynomial; use twenty_first::shared_math::traits::Inverse; use twenty_first::shared_math::x_field_element::XFieldElement; use crate::aet::AlgebraicExecutionTrace; -use crate::instruction::Instruction; use crate::table::challenges::ChallengeId::*; use crate::table::challenges::Challenges; use crate::table::constraint_circuit::DualRowIndicator::*; @@ -31,6 +35,42 @@ pub const BASE_WIDTH: usize = RamBaseTableColumn::COUNT; pub const EXT_WIDTH: usize = RamExtTableColumn::COUNT; pub const FULL_WIDTH: usize = BASE_WIDTH + EXT_WIDTH; +pub(crate) const INSTRUCTION_TYPE_WRITE: BFieldElement = BFIELD_ZERO; +pub(crate) const INSTRUCTION_TYPE_READ: BFieldElement = BFIELD_ONE; +pub(crate) const PADDING_INDICATOR: BFieldElement = BFieldElement::new(2); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Arbitrary)] +pub struct RamTableCall { + pub clk: u32, + pub ram_pointer: BFieldElement, + pub is_write: bool, + pub values: Vec, +} + +impl RamTableCall { + pub fn to_table_rows(self) -> Array2 { + let instruction_type = match self.is_write { + true => INSTRUCTION_TYPE_WRITE, + false => INSTRUCTION_TYPE_READ, + }; + let num_values = self.values.len(); + let pointers = (0..num_values) + .map(|offset| self.ram_pointer + BFieldElement::from(offset as u32)) + .collect::>(); + let values = Array1::from(self.values); + + let mut rows = Array2::zeros((num_values, BASE_WIDTH)); + rows.column_mut(CLK.base_table_index()) + .fill(self.clk.into()); + rows.column_mut(InstructionType.base_table_index()) + .fill(instruction_type); + rows.column_mut(RamPointer.base_table_index()) + .assign(&pointers); + rows.column_mut(RamValue.base_table_index()).assign(&values); + rows + } +} + #[derive(Debug, Clone)] pub struct RamTable {} @@ -43,177 +83,130 @@ impl RamTable { ram_table: &mut ArrayViewMut2, aet: &AlgebraicExecutionTrace, ) -> Vec { - // Store the registers relevant for the Ram Table, i.e., CLK, RAMP, RAMV, and - // PreviousInstruction, with RAMP as the key. Preserves, thus allows reusing, the order - // of the processor's rows, which are sorted by CLK. Note that the Ram Table does not - // have to be sorted by RAMP, but must form contiguous regions of RAMP values. - let mut pre_processed_ram_table: HashMap<_, Vec<_>> = HashMap::new(); - for processor_row in aet.processor_trace.rows() { - let clk = processor_row[ProcessorBaseTableColumn::CLK.base_table_index()]; - let ramp = processor_row[ProcessorBaseTableColumn::RAMP.base_table_index()]; - let ramv = processor_row[ProcessorBaseTableColumn::RAMV.base_table_index()]; - let previous_instruction = - processor_row[ProcessorBaseTableColumn::PreviousInstruction.base_table_index()]; - let ram_row = (clk, previous_instruction, ramv); - pre_processed_ram_table - .entry(ramp) - .and_modify(|v| v.push(ram_row)) - .or_insert_with(|| vec![ram_row]); + let mut ram_table = ram_table.slice_mut(s![0..aet.ram_table_length(), ..]); + let trace_iter = aet.ram_trace.rows().into_iter(); + + let sorted_rows = + trace_iter.sorted_by(|row_0, row_1| Self::compare_rows(row_0.view(), row_1.view())); + for (row_index, row) in sorted_rows.enumerate() { + ram_table.row_mut(row_index).assign(&row); + } + + let (bezout_0, bezout_1) = + Self::bezout_coefficient_polynomials_coefficients(ram_table.view()); + + Self::make_ram_table_consistent(&mut ram_table, bezout_0, bezout_1) + } + + fn compare_rows( + row_0: ArrayView1, + row_1: ArrayView1, + ) -> Ordering { + let ram_pointer_0 = row_0[RamPointer.base_table_index()].value(); + let ram_pointer_1 = row_1[RamPointer.base_table_index()].value(); + let compare_ram_pointers = ram_pointer_0.cmp(&ram_pointer_1); + + let clk_0 = row_0[CLK.base_table_index()].value(); + let clk_1 = row_1[CLK.base_table_index()].value(); + let compare_clocks = clk_0.cmp(&clk_1); + + compare_ram_pointers.then(compare_clocks) + } + + fn bezout_coefficient_polynomials_coefficients( + ram_table: ArrayView2, + ) -> (Vec, Vec) { + if ram_table.nrows() == 0 { + return (vec![], vec![]); + } + + let linear_poly_with_root = |&r: &BFieldElement| Polynomial::new(vec![-r, BFIELD_ONE]); + + let all_ram_pointers = ram_table.column(RamPointer.base_table_index()); + let unique_ram_pointers = all_ram_pointers.iter().unique(); + let num_unique_ram_pointers = unique_ram_pointers.clone().count(); + + let polynomial_with_ram_pointers_as_roots = unique_ram_pointers + .map(linear_poly_with_root) + .reduce(|accumulator, linear_poly| accumulator * linear_poly) + .unwrap_or_else(Polynomial::zero); + let formal_derivative = polynomial_with_ram_pointers_as_roots.formal_derivative(); + + let (gcd, bezout_poly_0, bezout_poly_1) = + Polynomial::xgcd(polynomial_with_ram_pointers_as_roots, formal_derivative); + + assert!(gcd.is_one()); + assert!(bezout_poly_0.degree() < num_unique_ram_pointers as isize); + assert!(bezout_poly_1.degree() <= num_unique_ram_pointers as isize); + + let mut coefficients_0 = bezout_poly_0.coefficients; + let mut coefficients_1 = bezout_poly_1.coefficients; + coefficients_0.resize(num_unique_ram_pointers, BFIELD_ZERO); + coefficients_1.resize(num_unique_ram_pointers, BFIELD_ZERO); + (coefficients_0, coefficients_1) + } + + /// - Set inverse of RAM pointer difference + /// - Fill in the Bézout coefficients if the RAM pointer changes between two consecutive rows + /// - Collect and return all clock jump differences + fn make_ram_table_consistent( + ram_table: &mut ArrayViewMut2, + mut bezout_coefficient_polynomial_coefficients_0: Vec, + mut bezout_coefficient_polynomial_coefficients_1: Vec, + ) -> Vec { + if ram_table.nrows() == 0 { + assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); + assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); + return vec![]; } - // Compute Bézout coefficient polynomials. - let num_of_ramps = pre_processed_ram_table.keys().len(); - let polynomial_with_ramps_as_roots = pre_processed_ram_table.keys().fold( - Polynomial::from_constant(BFieldElement::one()), - |acc, &ramp| acc * Polynomial::new(vec![-ramp, BFieldElement::one()]), // acc·(x - ramp) - ); - let formal_derivative = polynomial_with_ramps_as_roots.formal_derivative(); - let (gcd, bezout_0, bezout_1) = - Polynomial::xgcd(polynomial_with_ramps_as_roots, formal_derivative); - assert!(gcd.is_one(), "Each RAMP value must occur at most once."); - assert!( - bezout_0.degree() < num_of_ramps as isize, - "The Bézout coefficient 0 must be of degree at most {}.", - num_of_ramps - 1 - ); - assert!( - bezout_1.degree() <= num_of_ramps as isize, - "The Bézout coefficient 1 must be of degree at most {num_of_ramps}." - ); - let mut bezout_coefficient_polynomial_coefficients_0 = bezout_0.coefficients; - let mut bezout_coefficient_polynomial_coefficients_1 = bezout_1.coefficients; - bezout_coefficient_polynomial_coefficients_0.resize(num_of_ramps, BFieldElement::zero()); - bezout_coefficient_polynomial_coefficients_1.resize(num_of_ramps, BFieldElement::zero()); let mut current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); let mut current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); - ram_table[[ - 0, - BezoutCoefficientPolynomialCoefficient0.base_table_index(), - ]] = current_bcpc_0; - ram_table[[ - 0, - BezoutCoefficientPolynomialCoefficient1.base_table_index(), - ]] = current_bcpc_1; - - // Move the rows into the Ram Table as contiguous regions of RAMP values. Each such - // contiguous region is sorted by CLK by virtue of the order of the processor's rows. - let mut ram_table_row_idx = 0; - for (ramp, ram_table_rows) in pre_processed_ram_table { - for (clk, previous_instruction, ramv) in ram_table_rows { - let mut ram_table_row = ram_table.row_mut(ram_table_row_idx); - ram_table_row[CLK.base_table_index()] = clk; - ram_table_row[RAMP.base_table_index()] = ramp; - ram_table_row[RAMV.base_table_index()] = ramv; - ram_table_row[PreviousInstruction.base_table_index()] = previous_instruction; - ram_table_row_idx += 1; - } - } - assert_eq!(aet.processor_trace.nrows(), ram_table_row_idx); + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = + current_bcpc_0; + ram_table.row_mut(0)[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = + current_bcpc_1; - // - Set inverse of RAMP difference. - // - Fill in the Bézout coefficients if the RAMP has changed. - // - Collect all clock jump differences. - // The Ram Table and the Processor Table have the same length. let mut clock_jump_differences = vec![]; - for row_idx in 0..aet.processor_trace.nrows() - 1 { + for row_idx in 0..ram_table.nrows() - 1 { let (mut curr_row, mut next_row) = ram_table.multi_slice_mut((s![row_idx, ..], s![row_idx + 1, ..])); - let ramp_diff = next_row[RAMP.base_table_index()] - curr_row[RAMP.base_table_index()]; - let ramp_diff_inverse = ramp_diff.inverse_or_zero(); - curr_row[InverseOfRampDifference.base_table_index()] = ramp_diff_inverse; + let ramp_diff = + next_row[RamPointer.base_table_index()] - curr_row[RamPointer.base_table_index()]; + let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; - if !ramp_diff.is_zero() { + if ramp_diff.is_zero() { + assert!(!clk_diff.is_zero(), "row_idx = {row_idx}"); + clock_jump_differences.push(clk_diff); + } else { current_bcpc_0 = bezout_coefficient_polynomial_coefficients_0.pop().unwrap(); current_bcpc_1 = bezout_coefficient_polynomial_coefficients_1.pop().unwrap(); } + + curr_row[InverseOfRampDifference.base_table_index()] = ramp_diff.inverse_or_zero(); next_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()] = current_bcpc_0; next_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = current_bcpc_1; - - let clk_diff = next_row[CLK.base_table_index()] - curr_row[CLK.base_table_index()]; - if ramp_diff.is_zero() { - assert!( - !clk_diff.is_zero(), - "All rows must have distinct CLK values, but don't on row with index {row_idx}." - ); - clock_jump_differences.push(clk_diff); - } } assert_eq!(0, bezout_coefficient_polynomial_coefficients_0.len()); assert_eq!(0, bezout_coefficient_polynomial_coefficients_1.len()); - clock_jump_differences } - pub fn pad_trace(mut ram_table: ArrayViewMut2, processor_table_len: usize) { - assert!( - processor_table_len > 0, - "Processor Table must have at least 1 row." - ); - - // Set up indices for relevant sections of the table. - let padded_height = ram_table.nrows(); - let num_padding_rows = padded_height - processor_table_len; - let max_clk_before_padding = processor_table_len - 1; - let max_clk_before_padding_row_idx = ram_table - .rows() - .into_iter() - .enumerate() - .find(|(_, row)| row[CLK.base_table_index()].value() as usize == max_clk_before_padding) - .map(|(idx, _)| idx) - .expect("Ram Table must contain row with clock cycle equal to max cycle."); - let rows_to_move_source_section_start = max_clk_before_padding_row_idx + 1; - let rows_to_move_source_section_end = processor_table_len; - let num_rows_to_move = rows_to_move_source_section_end - rows_to_move_source_section_start; - let rows_to_move_dest_section_start = rows_to_move_source_section_start + num_padding_rows; - let rows_to_move_dest_section_end = rows_to_move_dest_section_start + num_rows_to_move; - let padding_section_start = rows_to_move_source_section_start; - let padding_section_end = padding_section_start + num_padding_rows; - assert_eq!(padded_height, rows_to_move_dest_section_end); - - // Move all rows below the row with highest CLK to the end of the table – if they exist. - if num_rows_to_move > 0 { - let rows_to_move_source_range = - rows_to_move_source_section_start..rows_to_move_source_section_end; - let rows_to_move_dest_range = - rows_to_move_dest_section_start..rows_to_move_dest_section_end; - let rows_to_move = ram_table - .slice(s![rows_to_move_source_range, ..]) - .to_owned(); - rows_to_move.move_into(&mut ram_table.slice_mut(s![rows_to_move_dest_range, ..])); + pub fn pad_trace(mut ram_table: ArrayViewMut2, ram_table_len: usize) { + let last_row_index = ram_table_len.saturating_sub(1); + let mut padding_row = ram_table.row(last_row_index).to_owned(); + padding_row[InstructionType.base_table_index()] = PADDING_INDICATOR; + if ram_table_len == 0 { + padding_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()] = BFIELD_ONE; } - // Fill the created gap with padding rows, i.e., with (adjusted) copies of the last row - // before the gap. This is the padding section. - let mut padding_row_template = ram_table.row(max_clk_before_padding_row_idx).to_owned(); - let ramp_difference_inverse = - padding_row_template[InverseOfRampDifference.base_table_index()]; - padding_row_template[InverseOfRampDifference.base_table_index()] = BFieldElement::zero(); - let mut padding_section = - ram_table.slice_mut(s![padding_section_start..padding_section_end, ..]); + let mut padding_section = ram_table.slice_mut(s![ram_table_len.., ..]); padding_section .axis_iter_mut(Axis(0)) .into_par_iter() - .for_each(|padding_row| padding_row_template.clone().move_into(padding_row)); - - // CLK keeps increasing by 1 also in the padding section. - let clk_range = processor_table_len..padded_height; - let clk_col = Array1::from_iter(clk_range.map(|clk| BFieldElement::new(clk as u64))); - clk_col.move_into(padding_section.slice_mut(s![.., CLK.base_table_index()])); - - // InverseOfRampDifference must be consistent at the padding section's boundaries. - ram_table[[ - max_clk_before_padding_row_idx, - InverseOfRampDifference.base_table_index(), - ]] = BFieldElement::zero(); - if num_rows_to_move > 0 && rows_to_move_dest_section_start > 0 { - let last_row_in_padding_section_idx = rows_to_move_dest_section_start - 1; - ram_table[[ - last_row_in_padding_section_idx, - InverseOfRampDifference.base_table_index(), - ]] = ramp_difference_inverse; - } + .for_each(|mut row| row.assign(&padding_row)); } pub fn extend( @@ -225,21 +218,15 @@ impl RamTable { assert_eq!(EXT_WIDTH, ext_table.ncols()); assert_eq!(base_table.nrows(), ext_table.nrows()); - let clk_weight = challenges[RamClkWeight]; - let ramp_weight = challenges[RamRampWeight]; - let ramv_weight = challenges[RamRamvWeight]; - let previous_instruction_weight = challenges[RamPreviousInstructionWeight]; - let processor_perm_indeterminate = challenges[RamIndeterminate]; - let bezout_relation_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; - let clock_jump_difference_lookup_indeterminate = - challenges[ClockJumpDifferenceLookupIndeterminate]; - let mut running_product_for_perm_arg = PermArg::default_initial(); let mut clock_jump_diff_lookup_log_derivative = LookupArg::default_initial(); // initialize columns establishing Bézout relation - let mut running_product_of_ramp = - bezout_relation_indeterminate - base_table.row(0)[RAMP.base_table_index()]; + let bezout_indeterminate = challenges[RamTableBezoutRelationIndeterminate]; + let clock_jump_difference_lookup_indeterminate = + challenges[ClockJumpDifferenceLookupIndeterminate]; + let mut running_product_ram_pointer = + bezout_indeterminate - base_table.row(0)[RamPointer.base_table_index()]; let mut formal_derivative = XFieldElement::one(); let mut bezout_coefficient_0 = base_table.row(0)[BezoutCoefficientPolynomialCoefficient0.base_table_index()].lift(); @@ -250,46 +237,49 @@ impl RamTable { for row_idx in 0..base_table.nrows() { let current_row = base_table.row(row_idx); let clk = current_row[CLK.base_table_index()]; - let ramp = current_row[RAMP.base_table_index()]; - let ramv = current_row[RAMV.base_table_index()]; - let previous_instruction = current_row[PreviousInstruction.base_table_index()]; - - if let Some(prev_row) = previous_row { - if prev_row[RAMP.base_table_index()] != current_row[RAMP.base_table_index()] { - // accumulate coefficient for Bézout relation, proving new RAMP is unique - let bcpc0 = - current_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()]; - let bcpc1 = - current_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()]; - - formal_derivative = (bezout_relation_indeterminate - ramp) * formal_derivative - + running_product_of_ramp; - running_product_of_ramp *= bezout_relation_indeterminate - ramp; - bezout_coefficient_0 = - bezout_coefficient_0 * bezout_relation_indeterminate + bcpc0; - bezout_coefficient_1 = - bezout_coefficient_1 * bezout_relation_indeterminate + bcpc1; - } else { - // prove that clock jump is directed forward - let clock_jump_difference = - current_row[CLK.base_table_index()] - prev_row[CLK.base_table_index()]; - clock_jump_diff_lookup_log_derivative += - (clock_jump_difference_lookup_indeterminate - clock_jump_difference) - .inverse(); + let instruction_type = current_row[InstructionType.base_table_index()]; + let current_ram_pointer = current_row[RamPointer.base_table_index()]; + let ram_value = current_row[RamValue.base_table_index()]; + + let is_no_padding_row = instruction_type != PADDING_INDICATOR; + + if is_no_padding_row { + if let Some(previous_row) = previous_row { + let previous_ram_pointer = previous_row[RamPointer.base_table_index()]; + if previous_ram_pointer != current_ram_pointer { + // accumulate coefficient for Bézout relation, proving new RAMP is unique + let bcpc0 = + current_row[BezoutCoefficientPolynomialCoefficient0.base_table_index()]; + let bcpc1 = + current_row[BezoutCoefficientPolynomialCoefficient1.base_table_index()]; + + formal_derivative = (bezout_indeterminate - current_ram_pointer) + * formal_derivative + + running_product_ram_pointer; + running_product_ram_pointer *= bezout_indeterminate - current_ram_pointer; + bezout_coefficient_0 = bezout_coefficient_0 * bezout_indeterminate + bcpc0; + bezout_coefficient_1 = bezout_coefficient_1 * bezout_indeterminate + bcpc1; + } else { + let previous_clock = previous_row[CLK.base_table_index()]; + let current_clock = current_row[CLK.base_table_index()]; + let clock_jump_difference = current_clock - previous_clock; + let log_derivative_summand = + clock_jump_difference_lookup_indeterminate - clock_jump_difference; + clock_jump_diff_lookup_log_derivative += log_derivative_summand.inverse(); + } } - } - // permutation argument to Processor Table - let compressed_row_for_permutation_argument = clk * clk_weight - + ramp * ramp_weight - + ramv * ramv_weight - + previous_instruction * previous_instruction_weight; - running_product_for_perm_arg *= - processor_perm_indeterminate - compressed_row_for_permutation_argument; + // permutation argument to Processor Table + let compressed_row = clk * challenges[RamClkWeight] + + instruction_type * challenges[RamInstructionTypeWeight] + + current_ram_pointer * challenges[RamPointerWeight] + + ram_value * challenges[RamValueWeight]; + running_product_for_perm_arg *= challenges[RamIndeterminate] - compressed_row; + } let mut extension_row = ext_table.row_mut(row_idx); extension_row[RunningProductPermArg.ext_table_index()] = running_product_for_perm_arg; - extension_row[RunningProductOfRAMP.ext_table_index()] = running_product_of_ramp; + extension_row[RunningProductOfRAMP.ext_table_index()] = running_product_ram_pointer; extension_row[FormalDerivative.ext_table_index()] = formal_derivative; extension_row[BezoutCoefficient0.ext_table_index()] = bezout_coefficient_0; extension_row[BezoutCoefficient1.ext_table_index()] = bezout_coefficient_1; @@ -304,51 +294,50 @@ impl ExtRamTable { pub fn initial_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let one = circuit_builder.b_constant(1_u32.into()); - - let bezout_challenge = circuit_builder.challenge(RamTableBezoutRelationIndeterminate); - let rppa_challenge = circuit_builder.challenge(RamIndeterminate); - - let clk = circuit_builder.input(BaseRow(CLK.master_base_table_index())); - let ramp = circuit_builder.input(BaseRow(RAMP.master_base_table_index())); - let ramv = circuit_builder.input(BaseRow(RAMV.master_base_table_index())); - let previous_instruction = - circuit_builder.input(BaseRow(PreviousInstruction.master_base_table_index())); - let bcpc0 = circuit_builder.input(BaseRow( - BezoutCoefficientPolynomialCoefficient0.master_base_table_index(), - )); - let bcpc1 = circuit_builder.input(BaseRow( - BezoutCoefficientPolynomialCoefficient1.master_base_table_index(), - )); - let rp = circuit_builder.input(ExtRow(RunningProductOfRAMP.master_ext_table_index())); - let fd = circuit_builder.input(ExtRow(FormalDerivative.master_ext_table_index())); - let bc0 = circuit_builder.input(ExtRow(BezoutCoefficient0.master_ext_table_index())); - let bc1 = circuit_builder.input(ExtRow(BezoutCoefficient1.master_ext_table_index())); - let rppa = circuit_builder.input(ExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative = circuit_builder.input(ExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let bezout_coefficient_polynomial_coefficient_0_is_0 = bcpc0; - let bezout_coefficient_0_is_0 = bc0; - let bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1 = bc1 - bcpc1; - let formal_derivative_is_1 = fd - one; - let running_product_polynomial_is_initialized_correctly = - rp - (bezout_challenge - ramp.clone()); - - let clock_jump_diff_log_derivative_is_initialized_correctly = clock_jump_diff_log_derivative - - circuit_builder.x_constant(LookupArg::default_initial()); - - let clk_weight = circuit_builder.challenge(RamClkWeight); - let ramp_weight = circuit_builder.challenge(RamRampWeight); - let ramv_weight = circuit_builder.challenge(RamRamvWeight); - let previous_instruction_weight = circuit_builder.challenge(RamPreviousInstructionWeight); - let compressed_row_for_permutation_argument = clk * clk_weight - + ramp * ramp_weight - + ramv * ramv_weight - + previous_instruction * previous_instruction_weight; - let running_product_permutation_argument_is_initialized_correctly = - rppa - (rppa_challenge - compressed_row_for_permutation_argument); + let challenge = |c| circuit_builder.challenge(c); + let constant = |c| circuit_builder.b_constant(c); + let x_constant = |c| circuit_builder.x_constant(c); + let base_row = |column: RamBaseTableColumn| { + circuit_builder.input(BaseRow(column.master_base_table_index())) + }; + let ext_row = |column: RamExtTableColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; + + let first_row_is_padding_row = base_row(InstructionType) - constant(PADDING_INDICATOR); + let first_row_is_not_padding_row = (base_row(InstructionType) + - constant(INSTRUCTION_TYPE_READ)) + * (base_row(InstructionType) - constant(INSTRUCTION_TYPE_WRITE)); + + let bezout_coefficient_polynomial_coefficient_0_is_0 = + base_row(BezoutCoefficientPolynomialCoefficient0); + let bezout_coefficient_0_is_0 = ext_row(BezoutCoefficient0); + let bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1 = + ext_row(BezoutCoefficient1) - base_row(BezoutCoefficientPolynomialCoefficient1); + let formal_derivative_is_1 = ext_row(FormalDerivative) - constant(1_u32.into()); + let running_product_polynomial_is_initialized_correctly = ext_row(RunningProductOfRAMP) + - challenge(RamTableBezoutRelationIndeterminate) + + base_row(RamPointer); + + let clock_jump_diff_log_derivative_is_default_initial = + ext_row(ClockJumpDifferenceLookupClientLogDerivative) + - x_constant(LookupArg::default_initial()); + + let compressed_row_for_permutation_argument = base_row(CLK) * challenge(RamClkWeight) + + base_row(InstructionType) * challenge(RamInstructionTypeWeight) + + base_row(RamPointer) * challenge(RamPointerWeight) + + base_row(RamValue) * challenge(RamValueWeight); + let running_product_permutation_argument_has_accumulated_first_row = + ext_row(RunningProductPermArg) - challenge(RamIndeterminate) + + compressed_row_for_permutation_argument; + let running_product_permutation_argument_is_default_initial = + ext_row(RunningProductPermArg) - x_constant(PermArg::default_initial()); + + let running_product_permutation_argument_starts_correctly = + running_product_permutation_argument_has_accumulated_first_row + * first_row_is_padding_row + + running_product_permutation_argument_is_default_initial + * first_row_is_not_padding_row; vec![ bezout_coefficient_polynomial_coefficient_0_is_0, @@ -356,8 +345,8 @@ impl ExtRamTable { bezout_coefficient_1_is_bezout_coefficient_polynomial_coefficient_1, running_product_polynomial_is_initialized_correctly, formal_derivative_is_1, - running_product_permutation_argument_is_initialized_correctly, - clock_jump_diff_log_derivative_is_initialized_correctly, + running_product_permutation_argument_starts_correctly, + clock_jump_diff_log_derivative_is_default_initial, ] } @@ -371,131 +360,147 @@ impl ExtRamTable { pub fn transition_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let one = circuit_builder.b_constant(1u32.into()); - - let bezout_challenge = circuit_builder.challenge(RamTableBezoutRelationIndeterminate); - let rppa_challenge = circuit_builder.challenge(RamIndeterminate); - let clk_weight = circuit_builder.challenge(RamClkWeight); - let ramp_weight = circuit_builder.challenge(RamRampWeight); - let ramv_weight = circuit_builder.challenge(RamRamvWeight); - let previous_instruction_weight = circuit_builder.challenge(RamPreviousInstructionWeight); - - let clk = circuit_builder.input(CurrentBaseRow(CLK.master_base_table_index())); - let ramp = circuit_builder.input(CurrentBaseRow(RAMP.master_base_table_index())); - let ramv = circuit_builder.input(CurrentBaseRow(RAMV.master_base_table_index())); - let iord = circuit_builder.input(CurrentBaseRow( - InverseOfRampDifference.master_base_table_index(), - )); - let bcpc0 = circuit_builder.input(CurrentBaseRow( - BezoutCoefficientPolynomialCoefficient0.master_base_table_index(), - )); - let bcpc1 = circuit_builder.input(CurrentBaseRow( - BezoutCoefficientPolynomialCoefficient1.master_base_table_index(), - )); - let rp = - circuit_builder.input(CurrentExtRow(RunningProductOfRAMP.master_ext_table_index())); - let fd = circuit_builder.input(CurrentExtRow(FormalDerivative.master_ext_table_index())); - let bc0 = circuit_builder.input(CurrentExtRow(BezoutCoefficient0.master_ext_table_index())); - let bc1 = circuit_builder.input(CurrentExtRow(BezoutCoefficient1.master_ext_table_index())); - let rppa = circuit_builder.input(CurrentExtRow( - RunningProductPermArg.master_ext_table_index(), - )); - let clock_jump_diff_log_derivative = circuit_builder.input(CurrentExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let clk_next = circuit_builder.input(NextBaseRow(CLK.master_base_table_index())); - let ramp_next = circuit_builder.input(NextBaseRow(RAMP.master_base_table_index())); - let ramv_next = circuit_builder.input(NextBaseRow(RAMV.master_base_table_index())); - let previous_instruction_next = - circuit_builder.input(NextBaseRow(PreviousInstruction.master_base_table_index())); - let bcpc0_next = circuit_builder.input(NextBaseRow( - BezoutCoefficientPolynomialCoefficient0.master_base_table_index(), - )); - let bcpc1_next = circuit_builder.input(NextBaseRow( - BezoutCoefficientPolynomialCoefficient1.master_base_table_index(), - )); - let rp_next = - circuit_builder.input(NextExtRow(RunningProductOfRAMP.master_ext_table_index())); - let fd_next = circuit_builder.input(NextExtRow(FormalDerivative.master_ext_table_index())); - let bc0_next = - circuit_builder.input(NextExtRow(BezoutCoefficient0.master_ext_table_index())); - let bc1_next = - circuit_builder.input(NextExtRow(BezoutCoefficient1.master_ext_table_index())); - let rppa_next = - circuit_builder.input(NextExtRow(RunningProductPermArg.master_ext_table_index())); - let clock_jump_diff_log_derivative_next = circuit_builder.input(NextExtRow( - ClockJumpDifferenceLookupClientLogDerivative.master_ext_table_index(), - )); - - let ramp_diff = ramp_next.clone() - ramp; - let ramp_changes = ramp_diff.clone() * iord.clone(); - - // iord is 0 or iord is the inverse of (ramp' - ramp) - let iord_is_0_or_iord_is_inverse_of_ramp_diff = iord * (ramp_changes.clone() - one.clone()); - - // (ramp' - ramp) is zero or iord is the inverse of (ramp' - ramp) - let ramp_diff_is_0_or_iord_is_inverse_of_ramp_diff = - ramp_diff.clone() * (ramp_changes.clone() - one.clone()); - - // (ramp doesn't change) and (previous instruction is not write_mem) - // implies the ramv doesn't change - let op_code_write_mem = circuit_builder.b_constant(Instruction::WriteMem.opcode_b()); - let ramp_changes_or_write_mem_or_ramv_stays = (one.clone() - ramp_changes.clone()) - * (op_code_write_mem - previous_instruction_next.clone()) - * (ramv_next.clone() - ramv); - - let bcbp0_only_changes_if_ramp_changes = - (one.clone() - ramp_changes.clone()) * (bcpc0_next.clone() - bcpc0); - - let bcbp1_only_changes_if_ramp_changes = - (one.clone() - ramp_changes.clone()) * (bcpc1_next.clone() - bcpc1); - - let running_product_ramp_updates_correctly = ramp_diff.clone() - * (rp_next.clone() - rp.clone() * (bezout_challenge.clone() - ramp_next.clone())) - + (one.clone() - ramp_changes.clone()) * (rp_next - rp.clone()); - - let formal_derivative_updates_correctly = ramp_diff.clone() - * (fd_next.clone() - rp - (bezout_challenge.clone() - ramp_next.clone()) * fd.clone()) - + (one.clone() - ramp_changes.clone()) * (fd_next - fd); - - let bezout_coefficient_0_is_constructed_correctly = ramp_diff.clone() + let constant = |c| circuit_builder.b_constant(c); + let challenge = |c| circuit_builder.challenge(c); + let curr_base_row = |column: RamBaseTableColumn| { + circuit_builder.input(CurrentBaseRow(column.master_base_table_index())) + }; + let curr_ext_row = |column: RamExtTableColumn| { + circuit_builder.input(CurrentExtRow(column.master_ext_table_index())) + }; + let next_base_row = |column: RamBaseTableColumn| { + circuit_builder.input(NextBaseRow(column.master_base_table_index())) + }; + let next_ext_row = |column: RamExtTableColumn| { + circuit_builder.input(NextExtRow(column.master_ext_table_index())) + }; + + let one = constant(1_u32.into()); + + let bezout_challenge = challenge(RamTableBezoutRelationIndeterminate); + + let clock = curr_base_row(CLK); + let ram_pointer = curr_base_row(RamPointer); + let ram_value = curr_base_row(RamValue); + let instruction_type = curr_base_row(InstructionType); + let inverse_of_ram_pointer_difference = curr_base_row(InverseOfRampDifference); + let bcpc0 = curr_base_row(BezoutCoefficientPolynomialCoefficient0); + let bcpc1 = curr_base_row(BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer = curr_ext_row(RunningProductOfRAMP); + let fd = curr_ext_row(FormalDerivative); + let bc0 = curr_ext_row(BezoutCoefficient0); + let bc1 = curr_ext_row(BezoutCoefficient1); + let rppa = curr_ext_row(RunningProductPermArg); + let clock_jump_diff_log_derivative = + curr_ext_row(ClockJumpDifferenceLookupClientLogDerivative); + + let clock_next = next_base_row(CLK); + let ram_pointer_next = next_base_row(RamPointer); + let ram_value_next = next_base_row(RamValue); + let instruction_type_next = next_base_row(InstructionType); + let bcpc0_next = next_base_row(BezoutCoefficientPolynomialCoefficient0); + let bcpc1_next = next_base_row(BezoutCoefficientPolynomialCoefficient1); + + let running_product_ram_pointer_next = next_ext_row(RunningProductOfRAMP); + let fd_next = next_ext_row(FormalDerivative); + let bc0_next = next_ext_row(BezoutCoefficient0); + let bc1_next = next_ext_row(BezoutCoefficient1); + let rppa_next = next_ext_row(RunningProductPermArg); + let clock_jump_diff_log_derivative_next = + next_ext_row(ClockJumpDifferenceLookupClientLogDerivative); + + let next_row_is_padding_row = + instruction_type_next.clone() - constant(PADDING_INDICATOR).clone(); + let if_current_row_is_padding_row_then_next_row_is_padding_row = (instruction_type.clone() + - constant(INSTRUCTION_TYPE_READ)) + * (instruction_type - constant(INSTRUCTION_TYPE_WRITE)) + * next_row_is_padding_row.clone(); + + let ram_pointer_difference = ram_pointer_next.clone() - ram_pointer; + let ram_pointer_changes = one.clone() + - ram_pointer_difference.clone() * inverse_of_ram_pointer_difference.clone(); + + let iord_is_0_or_iord_is_inverse_of_ram_pointer_difference = + inverse_of_ram_pointer_difference * ram_pointer_changes.clone(); + + let ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference = + ram_pointer_difference.clone() * ram_pointer_changes.clone(); + + let ram_pointer_changes_or_write_mem_or_ram_value_stays = ram_pointer_changes.clone() + * (constant(INSTRUCTION_TYPE_WRITE) - instruction_type_next.clone()) + * (ram_value_next.clone() - ram_value); + + let bcbp0_only_changes_if_ram_pointer_changes = + ram_pointer_changes.clone() * (bcpc0_next.clone() - bcpc0); + + let bcbp1_only_changes_if_ram_pointer_changes = + ram_pointer_changes.clone() * (bcpc1_next.clone() - bcpc1); + + let running_product_ram_pointer_updates_correctly = ram_pointer_difference.clone() + * (running_product_ram_pointer_next.clone() + - running_product_ram_pointer.clone() + * (bezout_challenge.clone() - ram_pointer_next.clone())) + + ram_pointer_changes.clone() + * (running_product_ram_pointer_next - running_product_ram_pointer.clone()); + + let formal_derivative_updates_correctly = ram_pointer_difference.clone() + * (fd_next.clone() + - running_product_ram_pointer + - (bezout_challenge.clone() - ram_pointer_next.clone()) * fd.clone()) + + ram_pointer_changes.clone() * (fd_next - fd); + + let bezout_coefficient_0_is_constructed_correctly = ram_pointer_difference.clone() * (bc0_next.clone() - bezout_challenge.clone() * bc0.clone() - bcpc0_next) - + (one.clone() - ramp_changes.clone()) * (bc0_next - bc0); + + ram_pointer_changes.clone() * (bc0_next - bc0); - let bezout_coefficient_1_is_constructed_correctly = ramp_diff.clone() + let bezout_coefficient_1_is_constructed_correctly = ram_pointer_difference.clone() * (bc1_next.clone() - bezout_challenge * bc1.clone() - bcpc1_next) - + (one.clone() - ramp_changes.clone()) * (bc1_next - bc1); - - let compressed_row_for_permutation_argument = clk_next.clone() * clk_weight - + ramp_next * ramp_weight - + ramv_next * ramv_weight - + previous_instruction_next * previous_instruction_weight; - let rppa_updates_correctly = - rppa_next - rppa * (rppa_challenge - compressed_row_for_permutation_argument); - - // The running sum of the logarithmic derivative for the clock jump difference Lookup - // Argument accumulates a summand of `clk_diff` if and only if the `ramp` does not change. - // Expressed differently: - // - the `ramp` changes or the log derivative accumulates a summand, and - // - the `ramp` does not change or the log derivative does not change. - let log_derivative_remains = - clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone(); - let clk_diff = clk_next - clk; - let log_derivative_accumulates = (clock_jump_diff_log_derivative_next - - clock_jump_diff_log_derivative) - * (circuit_builder.challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff) + + ram_pointer_changes.clone() * (bc1_next - bc1); + + let compressed_row = clock_next.clone() * challenge(RamClkWeight) + + ram_pointer_next * challenge(RamPointerWeight) + + ram_value_next * challenge(RamValueWeight) + + instruction_type_next.clone() * challenge(RamInstructionTypeWeight); + let rppa_accumulates_next_row = + rppa_next.clone() - rppa.clone() * (challenge(RamIndeterminate) - compressed_row); + + let next_row_is_not_padding_row = (instruction_type_next.clone() + - constant(INSTRUCTION_TYPE_READ)) + * (instruction_type_next - constant(INSTRUCTION_TYPE_WRITE)); + let rppa_remains_unchanged = rppa_next - rppa; + + let rppa_updates_correctly = rppa_accumulates_next_row * next_row_is_padding_row.clone() + + rppa_remains_unchanged * next_row_is_not_padding_row.clone(); + + let clock_difference = clock_next - clock; + let log_derivative_accumulates = (clock_jump_diff_log_derivative_next.clone() + - clock_jump_diff_log_derivative.clone()) + * (challenge(ClockJumpDifferenceLookupIndeterminate) - clock_difference) - one.clone(); + let log_derivative_remains = + clock_jump_diff_log_derivative_next - clock_jump_diff_log_derivative.clone(); + + let log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row = + log_derivative_accumulates * ram_pointer_changes.clone() * next_row_is_padding_row; + let log_derivative_remains_or_ram_pointer_doesnt_change = + log_derivative_remains.clone() * ram_pointer_difference.clone(); + let log_derivative_remains_or_next_row_is_not_padding_row = + log_derivative_remains * next_row_is_not_padding_row; + let log_derivative_updates_correctly = - (one - ramp_changes) * log_derivative_accumulates + ramp_diff * log_derivative_remains; + log_derivative_accumulates_or_ram_pointer_changes_or_next_row_is_padding_row + + log_derivative_remains_or_ram_pointer_doesnt_change + + log_derivative_remains_or_next_row_is_not_padding_row; vec![ - iord_is_0_or_iord_is_inverse_of_ramp_diff, - ramp_diff_is_0_or_iord_is_inverse_of_ramp_diff, - ramp_changes_or_write_mem_or_ramv_stays, - bcbp0_only_changes_if_ramp_changes, - bcbp1_only_changes_if_ramp_changes, - running_product_ramp_updates_correctly, + if_current_row_is_padding_row_then_next_row_is_padding_row, + iord_is_0_or_iord_is_inverse_of_ram_pointer_difference, + ram_pointer_difference_is_0_or_iord_is_inverse_of_ram_pointer_difference, + ram_pointer_changes_or_write_mem_or_ram_value_stays, + bcbp0_only_changes_if_ram_pointer_changes, + bcbp1_only_changes_if_ram_pointer_changes, + running_product_ram_pointer_updates_correctly, formal_derivative_updates_correctly, bezout_coefficient_0_is_constructed_correctly, bezout_coefficient_1_is_constructed_correctly, @@ -507,14 +512,14 @@ impl ExtRamTable { pub fn terminal_constraints( circuit_builder: &ConstraintCircuitBuilder, ) -> Vec> { - let one = circuit_builder.b_constant(1_u32.into()); - - let rp = circuit_builder.input(ExtRow(RunningProductOfRAMP.master_ext_table_index())); - let fd = circuit_builder.input(ExtRow(FormalDerivative.master_ext_table_index())); - let bc0 = circuit_builder.input(ExtRow(BezoutCoefficient0.master_ext_table_index())); - let bc1 = circuit_builder.input(ExtRow(BezoutCoefficient1.master_ext_table_index())); + let constant = |c: u32| circuit_builder.b_constant(c.into()); + let ext_row = |column: RamExtTableColumn| { + circuit_builder.input(ExtRow(column.master_ext_table_index())) + }; - let bezout_relation_holds = bc0 * rp + bc1 * fd - one; + let bezout_relation_holds = ext_row(BezoutCoefficient0) * ext_row(RunningProductOfRAMP) + + ext_row(BezoutCoefficient1) * ext_row(FormalDerivative) + - constant(1); vec![bezout_relation_holds] } @@ -522,8 +527,18 @@ impl ExtRamTable { #[cfg(test)] pub(crate) mod tests { + use proptest_arbitrary_interop::arb; + use test_strategy::proptest; + use super::*; + #[proptest] + fn ram_table_call_can_be_converted_to_table_rows( + #[strategy(arb())] ram_table_call: RamTableCall, + ) { + let _ = ram_table_call.to_table_rows(); + } + pub fn constraints_evaluate_to_zero( master_base_trace_table: ArrayView2, master_ext_trace_table: ArrayView2, diff --git a/triton-vm/src/table/table_column.rs b/triton-vm/src/table/table_column.rs index 4fae34fb0..b24e3f90e 100644 --- a/triton-vm/src/table/table_column.rs +++ b/triton-vm/src/table/table_column.rs @@ -104,7 +104,6 @@ pub enum ProgramExtTableColumn { pub enum ProcessorBaseTableColumn { CLK, IsPadding, - PreviousInstruction, IP, CI, NIA, @@ -141,8 +140,6 @@ pub enum ProcessorBaseTableColumn { HV3, HV4, HV5, - RAMP, - RAMV, /// The number of clock jump differences of magnitude `CLK` in all memory-like tables. ClockJumpDifferenceLookupMultiplicity, } @@ -199,9 +196,16 @@ pub enum OpStackExtTableColumn { #[derive(Display, Debug, Clone, Copy, PartialEq, Eq, EnumIter, EnumCount, Hash)] pub enum RamBaseTableColumn { CLK, - PreviousInstruction, - RAMP, - RAMV, + + /// Is [`INSTRUCTION_TYPE_READ`] for instruction `read_mem` and [`INSTRUCTION_TYPE_WRITE`] + /// for instruction `write_mem`. For padding rows, this is set to [`PADDING_INDICATOR`]. + /// + /// [`INSTRUCTION_TYPE_READ`]: crate::table::ram_table::INSTRUCTION_TYPE_READ + /// [`INSTRUCTION_TYPE_WRITE`]: crate::table::ram_table::INSTRUCTION_TYPE_WRITE + /// [`PADDING_INDICATOR`]: crate::table::ram_table::PADDING_INDICATOR + InstructionType, + RamPointer, + RamValue, InverseOfRampDifference, BezoutCoefficientPolynomialCoefficient0, BezoutCoefficientPolynomialCoefficient1, diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index a5baee12d..c19412327 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -32,6 +32,7 @@ use crate::table::hash_table::PermutationTrace; use crate::table::op_stack_table::OpStackTableEntry; use crate::table::processor_table; use crate::table::processor_table::ProcessorTraceRow; +use crate::table::ram_table::RamTableCall; use crate::table::table_column::*; use crate::table::u32_table::U32TableEntry; use crate::vm::CoProcessorCall::*; @@ -72,12 +73,6 @@ pub struct VMState<'pgm> { /// Current instruction's address in program memory pub instruction_pointer: usize, - /// The instruction that was executed last - pub previous_instruction: Option, - - /// RAM pointer - pub ram_pointer: u64, - /// The current state of the one, global Sponge that can be manipulated using instructions /// `SpongeInit`, `SpongeAbsorb`, and `SpongeSqueeze`. Instruction `SpongeInit` resets the /// Sponge state. @@ -103,6 +98,8 @@ pub enum CoProcessorCall { U32Call(U32TableEntry), OpStackCall(OpStackTableEntry), + + RamCall(RamTableCall), } impl<'pgm> VMState<'pgm> { @@ -129,8 +126,6 @@ impl<'pgm> VMState<'pgm> { jump_stack: vec![], cycle_count: 0, instruction_pointer: 0, - previous_instruction: Default::default(), - ram_pointer: 0, sponge_state: Default::default(), halting: false, } @@ -143,7 +138,8 @@ impl<'pgm> VMState<'pgm> { }; match current_instruction { - Pop(_) | Divine(_) | Dup(_) | Swap(_) | ReadIo(_) | WriteIo(_) => { + Pop(_) | Divine(_) | Dup(_) | Swap(_) | ReadMem(_) | WriteMem(_) | ReadIo(_) + | WriteIo(_) => { let arg_val: u64 = current_instruction.arg().unwrap().value(); hvs[0] = BFieldElement::new(arg_val % 2); hvs[1] = BFieldElement::new((arg_val >> 1) % 2); @@ -194,11 +190,6 @@ impl<'pgm> VMState<'pgm> { /// Perform the state transition as a mutable operation on `self`. pub fn step(&mut self) -> Result> { - // trying to read past the end of the program doesn't change the previous instruction - if let Ok(instruction) = self.current_instruction() { - self.previous_instruction = Some(instruction); - } - self.start_recording_op_stack_calls(); let mut co_processor_calls = match self.current_instruction()? { Pop(n) => self.pop(n)?, @@ -213,8 +204,8 @@ impl<'pgm> VMState<'pgm> { Return => self.return_from_call()?, Recurse => self.recurse()?, Assert => self.assert()?, - ReadMem => self.read_mem(), - WriteMem => self.write_mem()?, + ReadMem(n) => self.read_mem(n)?, + WriteMem(n) => self.write_mem(n)?, Hash => self.hash()?, SpongeInit => self.sponge_init(), SpongeAbsorb => self.sponge_absorb()?, @@ -376,25 +367,54 @@ impl<'pgm> VMState<'pgm> { vec![] } - fn read_mem(&mut self) -> Vec { - let ram_pointer = self.op_stack.peek_at(ST0); - self.ram_pointer = ram_pointer.value(); + fn read_mem(&mut self, n: NumberOfWords) -> Result> { + let mut ram_pointer = self.op_stack.pop()?; - let ram_value = self.memory_get(&ram_pointer); - self.op_stack.push(ram_value); + let mut ram_values = vec![]; + for _ in 0..n.num_words() { + ram_pointer.decrement(); + let ram_value = self.ram.get(&ram_pointer).copied().unwrap_or(BFIELD_ZERO); + self.op_stack.push(ram_value); + ram_values.push(ram_value); + } + ram_values.reverse(); - self.instruction_pointer += 1; - vec![] + self.op_stack.push(ram_pointer); + + let ram_table_call = RamTableCall { + clk: self.cycle_count, + ram_pointer, + is_write: false, + values: ram_values, + }; + + self.instruction_pointer += 2; + Ok(vec![RamCall(ram_table_call)]) } - fn write_mem(&mut self) -> Result> { - let ram_pointer = self.op_stack.peek_at(ST1); - let ram_value = self.op_stack.pop()?; - self.ram_pointer = ram_pointer.value(); - self.ram.insert(ram_pointer, ram_value); + fn write_mem(&mut self, n: NumberOfWords) -> Result> { + let mut ram_pointer = self.op_stack.pop()?; - self.instruction_pointer += 1; - Ok(vec![]) + let mut ram_values = vec![]; + for _ in 0..n.num_words() { + let ram_value = self.op_stack.pop()?; + self.ram.insert(ram_pointer, ram_value); + ram_values.push(ram_value); + ram_pointer.increment(); + } + + self.op_stack.push(ram_pointer); + + ram_pointer -= n.into(); + let ram_table_call = RamTableCall { + clk: self.cycle_count, + ram_pointer, + is_write: true, + values: ram_values, + }; + + self.instruction_pointer += 2; + Ok(vec![RamCall(ram_table_call)]) } fn hash(&mut self) -> Result> { @@ -718,16 +738,10 @@ impl<'pgm> VMState<'pgm> { use ProcessorBaseTableColumn::*; let mut processor_row = Array1::zeros(processor_table::BASE_WIDTH); - let previous_instruction = match self.previous_instruction { - Some(instruction) => instruction.opcode_b(), - None => BFIELD_ZERO, - }; let current_instruction = self.current_instruction().unwrap_or(Nop); let helper_variables = self.derive_helper_variables(); - let ram_pointer = self.ram_pointer.into(); processor_row[CLK.base_table_index()] = (self.cycle_count as u64).into(); - processor_row[PreviousInstruction.base_table_index()] = previous_instruction; processor_row[IP.base_table_index()] = (self.instruction_pointer as u32).into(); processor_row[CI.base_table_index()] = current_instruction.opcode_b(); processor_row[NIA.base_table_index()] = self.next_instruction_or_argument(); @@ -764,8 +778,6 @@ impl<'pgm> VMState<'pgm> { processor_row[HV3.base_table_index()] = helper_variables[3]; processor_row[HV4.base_table_index()] = helper_variables[4]; processor_row[HV5.base_table_index()] = helper_variables[5]; - processor_row[RAMP.base_table_index()] = ram_pointer; - processor_row[RAMV.base_table_index()] = self.memory_get(&ram_pointer); processor_row } @@ -836,11 +848,6 @@ impl<'pgm> VMState<'pgm> { maybe_jump_stack_element.ok_or(anyhow!(JumpStackIsEmpty)) } - fn memory_get(&self, memory_address: &BFieldElement) -> BFieldElement { - let maybe_memory_value = self.ram.get(memory_address).copied(); - maybe_memory_value.unwrap_or_else(BFieldElement::zero) - } - fn pop_secret_digest(&mut self) -> Result<[BFieldElement; DIGEST_LENGTH]> { let digest = self .secret_digests @@ -1072,9 +1079,14 @@ pub(crate) mod tests { } pub(crate) fn test_program_for_write_mem_read_mem() -> ProgramAndInput { - ProgramAndInput::without_input( - triton_program!(push 2 push 1 write_mem push 0 pop 1 read_mem assert halt), - ) + ProgramAndInput::without_input(triton_program! { + push 3 push 1 push 2 // _ 3 1 2 + push 7 // _ 3 1 2 7 + write_mem 3 // _ 10 + read_mem 2 // _ 3 1 8 + pop 1 // _ 3 1 + assert halt // _ 3 + }) } pub(crate) fn test_program_for_hash() -> ProgramAndInput { @@ -1625,15 +1637,14 @@ pub(crate) mod tests { // Not all addresses are read to have different access patterns: // - Some addresses are read before written to. // - Other addresses are written to before read. - for memory_address in memory_addresses.iter().take(num_memory_accesses / 4) { - instructions.extend(triton_asm!(push {memory_address} read_mem push 0 eq assert pop 1)); + for address in memory_addresses.iter().take(num_memory_accesses / 4) { + let address = address.value() + 1; + instructions.extend(triton_asm!(push {address} read_mem 1 pop 1 push 0 eq assert)); } // Write everything to RAM. - for (memory_address, memory_value) in memory_addresses.iter().zip_eq(memory_values.iter()) { - instructions.extend(triton_asm!( - push {memory_address} push {memory_value} write_mem pop 1 - )); + for (address, value) in memory_addresses.iter().zip_eq(memory_values.iter()) { + instructions.extend(triton_asm!(push {value} push {address} write_mem 1 pop 1)); } // Read back in random order and check that the values did not change. @@ -1643,11 +1654,10 @@ pub(crate) mod tests { reading_permutation.swap(i, j); } for idx in reading_permutation { - let memory_address = memory_addresses[idx]; - let memory_value = memory_values[idx]; - instructions.extend(triton_asm!( - push {memory_address} read_mem push {memory_value} eq assert pop 1 - )); + let address = memory_addresses[idx].value() + 1; + let value = memory_values[idx]; + instructions + .extend(triton_asm!(push {address} read_mem 1 pop 1 push {value} eq assert)); } // Overwrite half the values with new ones. @@ -1657,12 +1667,11 @@ pub(crate) mod tests { writing_permutation.swap(i, j); } for idx in 0..num_memory_accesses / 2 { - let memory_address = memory_addresses[writing_permutation[idx]]; + let address = memory_addresses[writing_permutation[idx]]; let new_memory_value = rng.gen(); memory_values[writing_permutation[idx]] = new_memory_value; - instructions.extend(triton_asm!( - push {memory_address} push {new_memory_value} write_mem pop 1 - )); + instructions + .extend(triton_asm!(push {new_memory_value} push {address} write_mem 1 pop 1)); } // Read back all, i.e., unchanged and overwritten values in (different from before) random @@ -1673,11 +1682,10 @@ pub(crate) mod tests { reading_permutation.swap(i, j); } for idx in reading_permutation { - let memory_address = memory_addresses[idx]; - let memory_value = memory_values[idx]; - instructions.extend(triton_asm!( - push {memory_address} read_mem push {memory_value} eq assert pop 1 - )); + let address = memory_addresses[idx].value() + 1; + let value = memory_values[idx]; + instructions + .extend(triton_asm!(push {address} read_mem 1 pop 1 push {value} eq assert)); } let program = triton_program! { { &instructions } halt }; @@ -1932,23 +1940,23 @@ pub(crate) mod tests { #[test] fn run_tvm_basic_ram_read_write() { let program = triton_program!( - push 5 push 6 write_mem pop 1 - push 15 push 16 write_mem pop 1 - push 5 read_mem pop 2 - push 15 read_mem pop 2 - push 5 push 7 write_mem pop 1 - push 15 read_mem - push 5 read_mem + push 8 push 5 write_mem 1 pop 1 // write 8 to address 5 + push 18 push 15 write_mem 1 pop 1 // write 18 to address 15 + push 6 read_mem 1 pop 2 // read from address 5 + push 16 read_mem 1 pop 2 // read from address 15 + push 7 push 5 write_mem 1 pop 1 // write 7 to address 5 + push 16 read_mem 1 // _ 18 15 + push 6 read_mem 1 // _ 18 15 7 5 halt ); let terminal_state = program .debug_terminal_state([].into(), [].into(), None, None) .unwrap(); - assert_eq!(BFieldElement::new(7), terminal_state.op_stack.peek_at(ST0)); - assert_eq!(BFieldElement::new(5), terminal_state.op_stack.peek_at(ST1)); - assert_eq!(BFieldElement::new(16), terminal_state.op_stack.peek_at(ST2)); - assert_eq!(BFieldElement::new(15), terminal_state.op_stack.peek_at(ST3)); + assert_eq!(BFieldElement::new(5), terminal_state.op_stack.peek_at(ST0)); + assert_eq!(BFieldElement::new(7), terminal_state.op_stack.peek_at(ST1)); + assert_eq!(BFieldElement::new(15), terminal_state.op_stack.peek_at(ST2)); + assert_eq!(BFieldElement::new(18), terminal_state.op_stack.peek_at(ST3)); } #[test] @@ -1958,17 +1966,20 @@ pub(crate) mod tests { // ↓ // _ 0 0 | push 0 // _ 0 0 | 0 - write_mem // _ 0 0 | - push 5 // _ 0 0 | 5 - swap 1 // _ 0 5 | 0 - push 3 // _ 0 5 | 0 3 - swap 1 // _ 0 5 | 3 0 + write_mem 1 // _ 0 1 | + push 5 // _ 0 1 | 5 + swap 1 // _ 0 5 | 1 + push 3 // _ 0 5 | 1 3 + swap 1 // _ 0 5 | 3 1 pop 1 // _ 0 5 | 3 - write_mem // _ 0 5 | - read_mem // _ 0 5 | 3 + write_mem 1 // _ 0 4 | + read_mem 1 // _ 0 5 | 3 swap 2 // _ 3 5 | 0 pop 1 // _ 3 5 | - read_mem // _ 3 5 | 3 + swap 1 // _ 5 3 | + push 1 // _ 5 3 | 1 + add // _ 5 4 | + read_mem 1 // _ 5 5 | 3 halt ); @@ -1977,7 +1988,7 @@ pub(crate) mod tests { .unwrap(); assert_eq!(BFieldElement::new(3), terminal_state.op_stack.peek_at(ST0)); assert_eq!(BFieldElement::new(5), terminal_state.op_stack.peek_at(ST1)); - assert_eq!(BFieldElement::new(3), terminal_state.op_stack.peek_at(ST2)); + assert_eq!(BFieldElement::new(5), terminal_state.op_stack.peek_at(ST2)); } #[test] @@ -2077,14 +2088,14 @@ pub(crate) mod tests { #[test] fn read_mem_unitialized() { - let program = triton_program!(read_mem halt); + let program = triton_program!(read_mem 3 halt); let (aet, _) = program.trace_execution([].into(), [].into()).unwrap(); assert_eq!(2, aet.processor_trace.nrows()); } #[test] fn read_non_deterministically_initialized_ram_at_address_0() { - let program = triton_program!(read_mem write_io 1 halt); + let program = triton_program!(push 1 read_mem 1 pop 1 write_io 1 halt); let mut initial_ram = HashMap::new(); initial_ram.insert(0_u64.into(), 42_u64.into()); @@ -2100,17 +2111,19 @@ pub(crate) mod tests { prove_with_low_security_level(&program, public_input, secret_input, &mut None); } - #[test] - fn read_non_deterministically_initialized_ram_at_random_address() { - let random_address = thread_rng().gen_range(1..2_u64.pow(16)); + #[proptest(cases = 10)] + fn read_non_deterministically_initialized_ram_at_random_address( + #[strategy(arb())] address: BFieldElement, + #[strategy(arb())] value: BFieldElement, + ) { let program = triton_program!( - read_mem write_io 1 - push {random_address} read_mem write_io 1 + read_mem 1 swap 1 write_io 1 + push {address.value() + 1} read_mem 1 pop 1 write_io 1 halt ); let mut initial_ram = HashMap::new(); - initial_ram.insert(random_address.into(), 1337_u64.into()); + initial_ram.insert(address, value); let public_input = PublicInput::new(vec![]); let secret_input = NonDeterminism::new(vec![]).with_ram(initial_ram); @@ -2119,7 +2132,7 @@ pub(crate) mod tests { .run(public_input.clone(), secret_input.clone()) .unwrap(); assert_eq!(0, public_output[0].value()); - assert_eq!(1337, public_output[1].value()); + assert_eq!(value, public_output[1]); prove_with_low_security_level(&program, public_input, secret_input, &mut None); } @@ -2184,12 +2197,8 @@ pub(crate) mod tests { let Ok(err) = err.downcast::() else { panic!("Sudoku verifier must fail with InstructionError on bad Sudoku."); }; - let AssertionFailed(ip, _, _) = err else { + let AssertionFailed(_, _, _) = err else { panic!("Sudoku verifier must fail with AssertionFailed on bad Sudoku."); }; - assert_eq!( - 15, ip, - "Sudoku verifier must fail on line 15 on bad Sudoku." - ); } }