Skip to content

Commit

Permalink
feat: only change VM state if instruction execution will work
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Dec 1, 2023
1 parent 868f49d commit d7fbb3f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 26 deletions.
23 changes: 19 additions & 4 deletions triton-vm/src/op_stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ impl OpStack {
Ok(element)
}

pub(crate) fn assert_is_u32(&self, stack_element: OpStackElement) -> Result<()> {
let element = self.peek_at(stack_element);
match element.value() <= u32::MAX as u64 {
true => Ok(()),
false => Err(FailedU32Conversion(element)),
}
}

pub(crate) fn pop_u32(&mut self) -> Result<u32> {
let element = self.pop()?;
let element = element
Expand All @@ -118,15 +126,23 @@ impl OpStack {
self.stack[top_of_stack_index - stack_element_index]
}

pub(crate) fn peek_at_top_extension_field_element(&self) -> XFieldElement {
let coefficient_0 = self.peek_at(ST0);
let coefficient_1 = self.peek_at(ST1);
let coefficient_2 = self.peek_at(ST2);
let coefficients = [coefficient_0, coefficient_1, coefficient_2];
XFieldElement::new(coefficients)
}

pub(crate) fn swap_top_with(&mut self, stack_element: OpStackElement) {
let stack_element_index = usize::from(stack_element);
let top_of_stack_index = self.stack.len() - 1;
self.stack
.swap(top_of_stack_index, top_of_stack_index - stack_element_index);
}

pub(crate) fn is_too_shallow(&self) -> bool {
self.stack.len() < OpStackElement::COUNT
pub(crate) fn would_be_too_shallow(&self, stack_delta: i32) -> bool {
self.stack.len() as i32 + stack_delta < OpStackElement::COUNT as i32
}

/// The address of the next free address of the op-stack. Equivalent to the current length of
Expand Down Expand Up @@ -660,8 +676,7 @@ mod tests {
assert!(op_stack.pointer().value() as usize == op_stack.stack.len());

// verify underflow
let _ = op_stack.pop().expect("can't pop");
assert!(op_stack.is_too_shallow());
assert!(op_stack.would_be_too_shallow(-1));
}

#[test]
Expand Down
72 changes: 50 additions & 22 deletions triton-vm/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,14 @@ impl VMState {

/// Perform the state transition as a mutable operation on `self`.
pub fn step(&mut self) -> Result<Vec<CoProcessorCall>> {
let current_instruction = self.current_instruction()?;
let op_stack_delta = current_instruction.op_stack_size_influence();
if self.op_stack.would_be_too_shallow(op_stack_delta) {
return Err(OpStackTooShallow);
}

self.start_recording_op_stack_calls();
let mut co_processor_calls = match self.current_instruction()? {
let mut co_processor_calls = match current_instruction {
Pop(n) => self.pop(n)?,
Push(field_element) => self.push(field_element),
Divine(n) => self.divine(n)?,
Expand Down Expand Up @@ -235,10 +241,6 @@ impl VMState {
let op_stack_calls = self.stop_recording_op_stack_calls();
co_processor_calls.extend(op_stack_calls);

if self.op_stack.is_too_shallow() {
return Err(OpStackTooShallow);
}

self.cycle_count += 1;

Ok(co_processor_calls)
Expand Down Expand Up @@ -293,11 +295,12 @@ impl VMState {
}

fn divine(&mut self, n: NumberOfWords) -> Result<Vec<CoProcessorCall>> {
for i in 0..n.num_words() {
let element = self
.secret_individual_tokens
.pop_front()
.ok_or(EmptySecretInput(i))?;
let input_len = self.secret_individual_tokens.len();
if input_len < n.num_words() {
return Err(EmptySecretInput(input_len));
}
for _ in 0..n.num_words() {
let element = self.secret_individual_tokens.pop_front().unwrap();
self.op_stack.push(element);
}

Expand Down Expand Up @@ -359,11 +362,11 @@ impl VMState {
}

fn assert(&mut self) -> Result<Vec<CoProcessorCall>> {
let top_of_stack = self.op_stack.pop()?;

let top_of_stack = self.op_stack.peek_at(ST0);
if !top_of_stack.is_one() {
return Err(AssertionFailed);
}
let _ = self.op_stack.pop()?;

self.instruction_pointer += 1;
Ok(vec![])
Expand All @@ -379,7 +382,7 @@ impl VMState {
self.start_recording_ram_calls();
let mut ram_pointer = self.op_stack.pop()?;
for _ in 0..n.num_words() {
let ram_value = self.ram_read(ram_pointer)?;
let ram_value = self.ram_read(ram_pointer);
self.op_stack.push(ram_value);
ram_pointer.decrement();
}
Expand All @@ -405,7 +408,7 @@ impl VMState {
Ok(ram_calls)
}

fn ram_read(&mut self, ram_pointer: BFieldElement) -> Result<BFieldElement> {
fn ram_read(&mut self, ram_pointer: BFieldElement) -> BFieldElement {
let ram_value = self.ram.get(&ram_pointer).copied().unwrap_or(BFIELD_ZERO);

let ram_table_call = RamTableCall {
Expand All @@ -416,7 +419,7 @@ impl VMState {
};
self.ram_calls.push(ram_table_call);

Ok(ram_value)
ram_value
}

fn ram_write(&mut self, ram_pointer: BFieldElement, ram_value: BFieldElement) {
Expand Down Expand Up @@ -485,6 +488,11 @@ impl VMState {
}

fn divine_sibling(&mut self) -> Result<Vec<CoProcessorCall>> {
if self.secret_digests.is_empty() {
return Err(EmptySecretDigestInput);
}
self.op_stack.assert_is_u32(ST5)?;

let known_digest = self.op_stack.pop_multiple()?;
let node_index = self.op_stack.pop_u32()?;

Expand Down Expand Up @@ -538,10 +546,11 @@ impl VMState {
}

fn invert(&mut self) -> Result<Vec<CoProcessorCall>> {
let top_of_stack = self.op_stack.pop()?;
let top_of_stack = self.op_stack.peek_at(ST0);
if top_of_stack.is_zero() {
return Err(InverseOfZero);
}
let _ = self.op_stack.pop()?;
self.op_stack.push(top_of_stack.inverse());
self.instruction_pointer += 1;
Ok(vec![])
Expand Down Expand Up @@ -572,6 +581,8 @@ impl VMState {
}

fn lt(&mut self) -> Result<Vec<CoProcessorCall>> {
self.op_stack.assert_is_u32(ST0)?;
self.op_stack.assert_is_u32(ST1)?;
let lhs = self.op_stack.pop_u32()?;
let rhs = self.op_stack.pop_u32()?;
let lt: u32 = (lhs < rhs).into();
Expand All @@ -585,6 +596,8 @@ impl VMState {
}

fn and(&mut self) -> Result<Vec<CoProcessorCall>> {
self.op_stack.assert_is_u32(ST0)?;
self.op_stack.assert_is_u32(ST1)?;
let lhs = self.op_stack.pop_u32()?;
let rhs = self.op_stack.pop_u32()?;
let and = lhs & rhs;
Expand All @@ -598,6 +611,8 @@ impl VMState {
}

fn xor(&mut self) -> Result<Vec<CoProcessorCall>> {
self.op_stack.assert_is_u32(ST0)?;
self.op_stack.assert_is_u32(ST1)?;
let lhs = self.op_stack.pop_u32()?;
let rhs = self.op_stack.pop_u32()?;
let xor = lhs ^ rhs;
Expand All @@ -614,10 +629,12 @@ impl VMState {
}

fn log_2_floor(&mut self) -> Result<Vec<CoProcessorCall>> {
let top_of_stack = self.op_stack.pop_u32()?;
self.op_stack.assert_is_u32(ST0)?;
let top_of_stack = self.op_stack.peek_at(ST0);
if top_of_stack.is_zero() {
return Err(LogarithmOfZero);
}
let top_of_stack = self.op_stack.pop_u32()?;
let log_2_floor = top_of_stack.ilog2();
self.op_stack.push(log_2_floor.into());

Expand All @@ -629,6 +646,7 @@ impl VMState {
}

fn pow(&mut self) -> Result<Vec<CoProcessorCall>> {
self.op_stack.assert_is_u32(ST1)?;
let base = self.op_stack.pop()?;
let exponent = self.op_stack.pop_u32()?;
let base_pow_exponent = base.mod_pow(exponent.into());
Expand All @@ -643,11 +661,15 @@ impl VMState {
}

fn div_mod(&mut self) -> Result<Vec<CoProcessorCall>> {
let numerator = self.op_stack.pop_u32()?;
let denominator = self.op_stack.pop_u32()?;
self.op_stack.assert_is_u32(ST0)?;
self.op_stack.assert_is_u32(ST1)?;
let denominator = self.op_stack.peek_at(ST1);
if denominator.is_zero() {
return Err(DivisionByZero);
}

let numerator = self.op_stack.pop_u32()?;
let denominator = self.op_stack.pop_u32()?;
let quotient = numerator / denominator;
let remainder = numerator % denominator;

Expand All @@ -666,6 +688,7 @@ impl VMState {
}

fn pop_count(&mut self) -> Result<Vec<CoProcessorCall>> {
self.op_stack.assert_is_u32(ST0)?;
let top_of_stack = self.op_stack.pop_u32()?;
let pop_count = top_of_stack.count_ones();
self.op_stack.push(pop_count.into());
Expand Down Expand Up @@ -694,11 +717,12 @@ impl VMState {
}

fn x_invert(&mut self) -> Result<Vec<CoProcessorCall>> {
let top_of_stack = self.op_stack.pop_extension_field_element()?;
let top_of_stack = self.op_stack.peek_at_top_extension_field_element();
if top_of_stack.is_zero() {
return Err(InverseOfZero);
}
let inverse = top_of_stack.inverse();
let _ = self.op_stack.pop_extension_field_element()?;
self.op_stack.push_extension_field_element(inverse);
self.instruction_pointer += 1;
Ok(vec![])
Expand All @@ -724,8 +748,12 @@ impl VMState {
}

fn read_io(&mut self, n: NumberOfWords) -> Result<Vec<CoProcessorCall>> {
for i in 0..n.num_words() {
let read_element = self.public_input.pop_front().ok_or(EmptyPublicInput(i))?;
let input_len = self.public_input.len();
if input_len < n.num_words() {
return Err(EmptyPublicInput(input_len));
}
for _ in 0..n.num_words() {
let read_element = self.public_input.pop_front().unwrap();
self.op_stack.push(read_element);
}

Expand Down

0 comments on commit d7fbb3f

Please sign in to comment.