From d7fbb3fdd64b783a735c9879678ff3e5bcfc7489 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Fri, 1 Dec 2023 10:50:18 +0100 Subject: [PATCH] feat: only change VM state if instruction execution will work --- triton-vm/src/op_stack.rs | 23 ++++++++++--- triton-vm/src/vm.rs | 72 +++++++++++++++++++++++++++------------ 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/triton-vm/src/op_stack.rs b/triton-vm/src/op_stack.rs index ec203b2e8..578b67c42 100644 --- a/triton-vm/src/op_stack.rs +++ b/triton-vm/src/op_stack.rs @@ -94,6 +94,14 @@ impl OpStack { Ok(element) } + pub(crate) fn assert_is_u32(&self, stack_element: OpStackElement) -> Result<()> { + let element = self.peek_at(stack_element); + match element.value() <= u32::MAX as u64 { + true => Ok(()), + false => Err(FailedU32Conversion(element)), + } + } + pub(crate) fn pop_u32(&mut self) -> Result { let element = self.pop()?; let element = element @@ -118,6 +126,14 @@ impl OpStack { self.stack[top_of_stack_index - stack_element_index] } + pub(crate) fn peek_at_top_extension_field_element(&self) -> XFieldElement { + let coefficient_0 = self.peek_at(ST0); + let coefficient_1 = self.peek_at(ST1); + let coefficient_2 = self.peek_at(ST2); + let coefficients = [coefficient_0, coefficient_1, coefficient_2]; + XFieldElement::new(coefficients) + } + pub(crate) fn swap_top_with(&mut self, stack_element: OpStackElement) { let stack_element_index = usize::from(stack_element); let top_of_stack_index = self.stack.len() - 1; @@ -125,8 +141,8 @@ impl OpStack { .swap(top_of_stack_index, top_of_stack_index - stack_element_index); } - pub(crate) fn is_too_shallow(&self) -> bool { - self.stack.len() < OpStackElement::COUNT + pub(crate) fn would_be_too_shallow(&self, stack_delta: i32) -> bool { + self.stack.len() as i32 + stack_delta < OpStackElement::COUNT as i32 } /// The address of the next free address of the op-stack. Equivalent to the current length of @@ -660,8 +676,7 @@ mod tests { assert!(op_stack.pointer().value() as usize == op_stack.stack.len()); // verify underflow - let _ = op_stack.pop().expect("can't pop"); - assert!(op_stack.is_too_shallow()); + assert!(op_stack.would_be_too_shallow(-1)); } #[test] diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index 9694c0eb5..b66430edb 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -191,8 +191,14 @@ impl VMState { /// Perform the state transition as a mutable operation on `self`. pub fn step(&mut self) -> Result> { + let current_instruction = self.current_instruction()?; + let op_stack_delta = current_instruction.op_stack_size_influence(); + if self.op_stack.would_be_too_shallow(op_stack_delta) { + return Err(OpStackTooShallow); + } + self.start_recording_op_stack_calls(); - let mut co_processor_calls = match self.current_instruction()? { + let mut co_processor_calls = match current_instruction { Pop(n) => self.pop(n)?, Push(field_element) => self.push(field_element), Divine(n) => self.divine(n)?, @@ -235,10 +241,6 @@ impl VMState { let op_stack_calls = self.stop_recording_op_stack_calls(); co_processor_calls.extend(op_stack_calls); - if self.op_stack.is_too_shallow() { - return Err(OpStackTooShallow); - } - self.cycle_count += 1; Ok(co_processor_calls) @@ -293,11 +295,12 @@ impl VMState { } fn divine(&mut self, n: NumberOfWords) -> Result> { - for i in 0..n.num_words() { - let element = self - .secret_individual_tokens - .pop_front() - .ok_or(EmptySecretInput(i))?; + let input_len = self.secret_individual_tokens.len(); + if input_len < n.num_words() { + return Err(EmptySecretInput(input_len)); + } + for _ in 0..n.num_words() { + let element = self.secret_individual_tokens.pop_front().unwrap(); self.op_stack.push(element); } @@ -359,11 +362,11 @@ impl VMState { } fn assert(&mut self) -> Result> { - let top_of_stack = self.op_stack.pop()?; - + let top_of_stack = self.op_stack.peek_at(ST0); if !top_of_stack.is_one() { return Err(AssertionFailed); } + let _ = self.op_stack.pop()?; self.instruction_pointer += 1; Ok(vec![]) @@ -379,7 +382,7 @@ impl VMState { self.start_recording_ram_calls(); let mut ram_pointer = self.op_stack.pop()?; for _ in 0..n.num_words() { - let ram_value = self.ram_read(ram_pointer)?; + let ram_value = self.ram_read(ram_pointer); self.op_stack.push(ram_value); ram_pointer.decrement(); } @@ -405,7 +408,7 @@ impl VMState { Ok(ram_calls) } - fn ram_read(&mut self, ram_pointer: BFieldElement) -> Result { + fn ram_read(&mut self, ram_pointer: BFieldElement) -> BFieldElement { let ram_value = self.ram.get(&ram_pointer).copied().unwrap_or(BFIELD_ZERO); let ram_table_call = RamTableCall { @@ -416,7 +419,7 @@ impl VMState { }; self.ram_calls.push(ram_table_call); - Ok(ram_value) + ram_value } fn ram_write(&mut self, ram_pointer: BFieldElement, ram_value: BFieldElement) { @@ -485,6 +488,11 @@ impl VMState { } fn divine_sibling(&mut self) -> Result> { + if self.secret_digests.is_empty() { + return Err(EmptySecretDigestInput); + } + self.op_stack.assert_is_u32(ST5)?; + let known_digest = self.op_stack.pop_multiple()?; let node_index = self.op_stack.pop_u32()?; @@ -538,10 +546,11 @@ impl VMState { } fn invert(&mut self) -> Result> { - let top_of_stack = self.op_stack.pop()?; + let top_of_stack = self.op_stack.peek_at(ST0); if top_of_stack.is_zero() { return Err(InverseOfZero); } + let _ = self.op_stack.pop()?; self.op_stack.push(top_of_stack.inverse()); self.instruction_pointer += 1; Ok(vec![]) @@ -572,6 +581,8 @@ impl VMState { } fn lt(&mut self) -> Result> { + self.op_stack.assert_is_u32(ST0)?; + self.op_stack.assert_is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; let rhs = self.op_stack.pop_u32()?; let lt: u32 = (lhs < rhs).into(); @@ -585,6 +596,8 @@ impl VMState { } fn and(&mut self) -> Result> { + self.op_stack.assert_is_u32(ST0)?; + self.op_stack.assert_is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; let rhs = self.op_stack.pop_u32()?; let and = lhs & rhs; @@ -598,6 +611,8 @@ impl VMState { } fn xor(&mut self) -> Result> { + self.op_stack.assert_is_u32(ST0)?; + self.op_stack.assert_is_u32(ST1)?; let lhs = self.op_stack.pop_u32()?; let rhs = self.op_stack.pop_u32()?; let xor = lhs ^ rhs; @@ -614,10 +629,12 @@ impl VMState { } fn log_2_floor(&mut self) -> Result> { - let top_of_stack = self.op_stack.pop_u32()?; + self.op_stack.assert_is_u32(ST0)?; + let top_of_stack = self.op_stack.peek_at(ST0); if top_of_stack.is_zero() { return Err(LogarithmOfZero); } + let top_of_stack = self.op_stack.pop_u32()?; let log_2_floor = top_of_stack.ilog2(); self.op_stack.push(log_2_floor.into()); @@ -629,6 +646,7 @@ impl VMState { } fn pow(&mut self) -> Result> { + self.op_stack.assert_is_u32(ST1)?; let base = self.op_stack.pop()?; let exponent = self.op_stack.pop_u32()?; let base_pow_exponent = base.mod_pow(exponent.into()); @@ -643,11 +661,15 @@ impl VMState { } fn div_mod(&mut self) -> Result> { - let numerator = self.op_stack.pop_u32()?; - let denominator = self.op_stack.pop_u32()?; + self.op_stack.assert_is_u32(ST0)?; + self.op_stack.assert_is_u32(ST1)?; + let denominator = self.op_stack.peek_at(ST1); if denominator.is_zero() { return Err(DivisionByZero); } + + let numerator = self.op_stack.pop_u32()?; + let denominator = self.op_stack.pop_u32()?; let quotient = numerator / denominator; let remainder = numerator % denominator; @@ -666,6 +688,7 @@ impl VMState { } fn pop_count(&mut self) -> Result> { + self.op_stack.assert_is_u32(ST0)?; let top_of_stack = self.op_stack.pop_u32()?; let pop_count = top_of_stack.count_ones(); self.op_stack.push(pop_count.into()); @@ -694,11 +717,12 @@ impl VMState { } fn x_invert(&mut self) -> Result> { - let top_of_stack = self.op_stack.pop_extension_field_element()?; + let top_of_stack = self.op_stack.peek_at_top_extension_field_element(); if top_of_stack.is_zero() { return Err(InverseOfZero); } let inverse = top_of_stack.inverse(); + let _ = self.op_stack.pop_extension_field_element()?; self.op_stack.push_extension_field_element(inverse); self.instruction_pointer += 1; Ok(vec![]) @@ -724,8 +748,12 @@ impl VMState { } fn read_io(&mut self, n: NumberOfWords) -> Result> { - for i in 0..n.num_words() { - let read_element = self.public_input.pop_front().ok_or(EmptyPublicInput(i))?; + let input_len = self.public_input.len(); + if input_len < n.num_words() { + return Err(EmptyPublicInput(input_len)); + } + for _ in 0..n.num_words() { + let read_element = self.public_input.pop_front().unwrap(); self.op_stack.push(read_element); }