From c3bbe62f73adb3cf627374e2bb8740fbbdd6e794 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Mon, 1 Jul 2024 14:20:21 -0400 Subject: [PATCH 1/9] Refactor memory chip and program chip --- vm/src/cpu/trace.rs | 25 +---- vm/src/memory/mod.rs | 4 +- vm/src/memory/offline_checker/air.rs | 8 +- vm/src/memory/offline_checker/bridge.rs | 4 +- vm/src/memory/offline_checker/columns.rs | 8 +- vm/src/memory/offline_checker/mod.rs | 95 +++++++++++++--- vm/src/memory/offline_checker/trace.rs | 68 +++++------ vm/src/memory/tests.rs | 42 +++---- vm/src/program/mod.rs | 28 +++-- vm/src/program/tests/mod.rs | 15 ++- vm/src/program/trace.rs | 22 ++-- vm/src/vm/mod.rs | 137 +++++++---------------- 12 files changed, 215 insertions(+), 241 deletions(-) diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index 1037785107..521e8b748e 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -12,7 +12,7 @@ use afs_chips::{ is_equal_vec::IsEqualVecAir, is_zero::IsZeroAir, sub_chip::LocalTraceInstructions, }; -use crate::{field_arithmetic::FieldArithmeticAir, memory::OpType}; +use crate::{field_arithmetic::FieldArithmeticAir, memory::OpType, vm::VirtualMachine}; use super::{ columns::{CpuAuxCols, CpuCols, CpuIoCols, MemoryAccessCols}, @@ -138,25 +138,6 @@ impl FieldExtensionOperation { } } -pub struct ProgramExecution { - pub program: Vec>, - pub trace_rows: Vec>, - pub execution_frequencies: Vec, - pub memory_accesses: Vec>, - pub arithmetic_ops: Vec>, -} - -impl ProgramExecution { - pub fn trace(&self, options: CpuOptions) -> RowMajorMatrix { - let rows: Vec = self - .trace_rows - .iter() - .flat_map(|row| row.flatten(options)) - .collect(); - RowMajorMatrix::new(rows, CpuCols::::get_width(options)) - } -} - struct Memory { data: HashMap>, log: Vec>, @@ -265,8 +246,8 @@ impl Error for ExecutionError {} impl CpuAir { pub fn generate_program_execution( &self, - program: Vec>, - ) -> Result, ExecutionError> { + vm: VirtualMachine, + ) -> RowMajorMatrix { let mut rows = vec![]; let mut execution_frequencies = vec![F::zero(); program.len()]; let mut arithmetic_operations = vec![]; diff --git a/vm/src/memory/mod.rs b/vm/src/memory/mod.rs index 67a7bfe6b5..0bb36ba6c6 100644 --- a/vm/src/memory/mod.rs +++ b/vm/src/memory/mod.rs @@ -9,10 +9,10 @@ pub enum OpType { } #[derive(Clone, Debug)] -pub struct MemoryAccess { +pub struct MemoryAccess { pub timestamp: usize, pub op_type: OpType, pub address_space: F, pub address: F, - pub data: Vec, + pub data: [F; WORD_SIZE], } diff --git a/vm/src/memory/offline_checker/air.rs b/vm/src/memory/offline_checker/air.rs index 619cb337dc..2f7ac91708 100644 --- a/vm/src/memory/offline_checker/air.rs +++ b/vm/src/memory/offline_checker/air.rs @@ -16,17 +16,17 @@ use afs_chips::{ sub_chip::{AirConfig, SubAir}, }; -impl AirConfig for OfflineChecker { +impl AirConfig for OfflineChecker { type Cols = OfflineCheckerCols; } -impl BaseAir for OfflineChecker { +impl BaseAir for OfflineChecker { fn width(&self) -> usize { self.air_width() } } -impl Air for OfflineChecker +impl Air for OfflineChecker where AB::M: Clone, { @@ -105,7 +105,7 @@ where next_cols.is_equal_data_aux.prods, next_cols.is_equal_data_aux.invs, ); - let is_equal_data_air = IsEqualVecAir::new(self.data_len); + let is_equal_data_air = IsEqualVecAir::new(WORD_SIZE); SubAir::eval( &is_equal_data_air, diff --git a/vm/src/memory/offline_checker/bridge.rs b/vm/src/memory/offline_checker/bridge.rs index c3fd79c539..9bcb8c4c77 100644 --- a/vm/src/memory/offline_checker/bridge.rs +++ b/vm/src/memory/offline_checker/bridge.rs @@ -12,7 +12,7 @@ use afs_chips::is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupl use afs_chips::is_less_than_tuple::IsLessThanTupleAir; use afs_chips::sub_chip::SubAirBridge; -impl SubAirBridge for OfflineChecker { +impl SubAirBridge for OfflineChecker { /// Receives operations (clk, op_type, addr_space, pointer, data) fn receives(&self, col_indices: OfflineCheckerCols) -> Vec> { let op_cols: Vec> = iter::once(col_indices.clk) @@ -50,7 +50,7 @@ impl SubAirBridge for OfflineChecker { } } -impl AirBridge for OfflineChecker { +impl AirBridge for OfflineChecker { fn receives(&self) -> Vec> { let num_cols = self.air_width(); let all_cols = (0..num_cols).collect::>(); diff --git a/vm/src/memory/offline_checker/columns.rs b/vm/src/memory/offline_checker/columns.rs index 2b83cf6329..bddf4f223f 100644 --- a/vm/src/memory/offline_checker/columns.rs +++ b/vm/src/memory/offline_checker/columns.rs @@ -97,7 +97,7 @@ where flattened } - pub fn from_slice(slc: &[T], oc: &OfflineChecker) -> Self { + pub fn from_slice(slc: &[T], oc: &OfflineChecker) -> Self { assert!(slc.len() == oc.air_width()); let mem_width = oc.mem_width(); @@ -114,11 +114,11 @@ where is_equal_addr_space_aux: IsEqualAuxCols::from_slice(&slc[8 + mem_width..9 + mem_width]), is_equal_pointer_aux: IsEqualAuxCols::from_slice(&slc[9 + mem_width..10 + mem_width]), is_equal_data_aux: IsEqualVecAuxCols::from_slice( - &slc[10 + mem_width..10 + mem_width + 2 * oc.data_len], - oc.data_len, + &slc[10 + mem_width..10 + mem_width + 2 * WORD_SIZE], + WORD_SIZE, ), lt_aux: IsLessThanTupleAuxCols::from_slice( - &slc[10 + mem_width + 2 * oc.data_len..], + &slc[10 + mem_width + 2 * WORD_SIZE..], oc.addr_clk_limb_bits.clone(), oc.decomp, 3, diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 19b5acd899..72079bf54a 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -1,43 +1,102 @@ +use std::collections::HashMap; + use afs_chips::is_less_than_tuple::columns::IsLessThanTupleAuxCols; +use p3_field::PrimeField32; + +use crate::memory::OpType; + +use super::MemoryAccess; mod air; mod bridge; mod columns; mod trace; -pub struct OfflineChecker { - data_len: usize, +pub struct OfflineChecker { addr_clk_limb_bits: Vec, decomp: usize, } -impl OfflineChecker { +impl OfflineChecker { + pub fn mem_width(&self) -> usize { + // 1 for addr_space, 1 for pointer, data_len for data + 2 + WORD_SIZE + } + + pub fn air_width(&self) -> usize { + 10 + self.mem_width() + + 2 * WORD_SIZE + + IsLessThanTupleAuxCols::::get_width( + self.addr_clk_limb_bits.clone(), + self.decomp, + 3, + ) + } +} + +pub struct MemoryChip { + pub air: OfflineChecker, + pub accesses: Vec>, + memory: HashMap<(F, F), [F; WORD_SIZE]>, + last_timestamp: Option, +} + +impl MemoryChip { pub fn new( - data_len: usize, addr_space_limb_bits: usize, pointer_limb_bits: usize, clk_limb_bits: usize, decomp: usize, ) -> Self { Self { - data_len, - addr_clk_limb_bits: vec![addr_space_limb_bits, pointer_limb_bits, clk_limb_bits], - decomp, + air: OfflineChecker { + addr_clk_limb_bits: vec![addr_space_limb_bits, pointer_limb_bits, clk_limb_bits], + decomp, + }, + accesses: vec![], + memory: HashMap::new(), + last_timestamp: None, } } - pub fn mem_width(&self) -> usize { - // 1 for addr_space, 1 for pointer, data_len for data - 2 + self.data_len + pub fn read_word(&mut self, timestamp: usize, address_space: F, address: F) -> [F; WORD_SIZE] { + assert!(address_space != F::zero()); + assert!(timestamp % WORD_SIZE == 0); + if let Some(last_timestamp) = self.last_timestamp { + assert!(timestamp > last_timestamp); + } + self.last_timestamp = Some(timestamp); + let data = self.memory[&(address_space, address)]; + self.accesses.push(MemoryAccess { + timestamp, + op_type: OpType::Read, + address_space, + address, + data: data.clone(), + }); + data } - pub fn air_width(&self) -> usize { - 10 + self.mem_width() - + 2 * self.data_len - + IsLessThanTupleAuxCols::::get_width( - self.addr_clk_limb_bits.clone(), - self.decomp, - 3, - ) + pub fn write_word( + &mut self, + timestamp: usize, + address_space: F, + address: F, + data: [F; WORD_SIZE], + ) { + assert!(address_space != F::zero()); + assert!(timestamp % WORD_SIZE == 0); + if let Some(last_timestamp) = self.last_timestamp { + assert!(timestamp > last_timestamp); + } + self.last_timestamp = Some(timestamp); + self.memory.insert((address_space, address), data); + self.accesses.push(MemoryAccess { + timestamp, + op_type: OpType::Write, + address_space, + address, + data: data.clone(), + }); } } diff --git a/vm/src/memory/offline_checker/trace.rs b/vm/src/memory/offline_checker/trace.rs index fefdaa1c75..a8a912248d 100644 --- a/vm/src/memory/offline_checker/trace.rs +++ b/vm/src/memory/offline_checker/trace.rs @@ -7,13 +7,13 @@ use p3_matrix::dense::RowMajorMatrix; use crate::memory::{MemoryAccess, OpType}; -use super::OfflineChecker; +use super::MemoryChip; use afs_chips::is_equal_vec::IsEqualVecAir; use afs_chips::is_less_than_tuple::IsLessThanTupleAir; use afs_chips::range_gate::RangeCheckerGateChip; use afs_chips::sub_chip::LocalTraceInstructions; -impl OfflineChecker { +impl MemoryChip { /// Each row in the trace follow the same order as the Cols struct: /// [clk, mem_row, op_type, same_addr_space, same_pointer, same_addr, same_data, lt_bit, is_valid, is_equal_addr_space_aux, is_equal_pointer_aux, is_equal_data_aux, lt_aux] /// @@ -21,13 +21,13 @@ impl OfflineChecker { /// The trace is sorted by addr (addr_space and pointer) and then by clk, so every addr has a block of consective rows in the trace with the following structure /// A row is added to the trace for every read/write operation with the corresponding data /// The trace is padded at the end to be of height trace_degree - pub fn generate_trace( - &self, - mut ops: Vec>, + pub fn generate_trace( + &mut self, range_checker: Arc, trace_degree: usize, ) -> RowMajorMatrix { - ops.sort_by_key(|op| (op.address_space, op.address, op.timestamp)); + self.accesses + .sort_by_key(|op| (op.address_space, op.address, op.timestamp)); let mut rows: Vec = vec![]; @@ -36,25 +36,25 @@ impl OfflineChecker { op_type: OpType::Read, address_space: F::zero(), address: F::zero(), - data: vec![F::zero(); self.data_len], + data: [F::zero(); WORD_SIZE], }; - if !ops.is_empty() { - rows.extend(self.generate_trace_row::( + if !self.accesses.is_empty() { + rows.extend(self.generate_trace_row( true, 1, - &ops[0], + &self.accesses[0], &dummy_op, range_checker.clone(), )); } - for i in 1..ops.len() { - rows.extend(self.generate_trace_row::( + for i in 1..self.accesses.len() { + rows.extend(self.generate_trace_row( false, 1, - &ops[i], - &ops[i - 1], + &self.accesses[i], + &self.accesses[i - 1], range_checker.clone(), )); } @@ -62,18 +62,18 @@ impl OfflineChecker { // Ensure that trace degree is a power of two assert!(trace_degree > 0 && trace_degree & (trace_degree - 1) == 0); - if ops.len() < trace_degree { - rows.extend(self.generate_trace_row::( + if self.accesses.len() < trace_degree { + rows.extend(self.generate_trace_row( false, 0, &dummy_op, - &ops[ops.len() - 1], + &self.accesses[self.accesses.len() - 1], range_checker.clone(), )); } - for _i in 1..(trace_degree - ops.len()) { - rows.extend(self.generate_trace_row::( + for _i in 1..(trace_degree - self.accesses.len()) { + rows.extend(self.generate_trace_row( false, 0, &dummy_op, @@ -82,15 +82,15 @@ impl OfflineChecker { )); } - RowMajorMatrix::new(rows, self.air_width()) + RowMajorMatrix::new(rows, self.air.air_width()) } - pub fn generate_trace_row( + pub fn generate_trace_row( &self, is_first_row: bool, is_valid: u8, - curr_op: &MemoryAccess, - prev_op: &MemoryAccess, + curr_op: &MemoryAccess, + prev_op: &MemoryAccess, range_checker: Arc, ) -> Vec { let mut row: Vec = vec![]; @@ -136,11 +136,11 @@ impl OfflineChecker { let is_equal_addr_space_air = IsEqualAir {}; let is_equal_pointer_air = IsEqualAir {}; - let is_equal_data_air = IsEqualVecAir::new(self.data_len); + let is_equal_data_air = IsEqualVecAir::new(WORD_SIZE); let lt_air = IsLessThanTupleAir::new( range_checker.bus_index(), - self.addr_clk_limb_bits.clone(), - self.decomp, + self.air.addr_clk_limb_bits.clone(), + self.air.decomp, ); let is_equal_addr_space_aux = is_equal_addr_space_air @@ -150,8 +150,8 @@ impl OfflineChecker { .generate_trace_row((prev_op.address, curr_op.address)) .flatten()[3]; let is_equal_data_aux = is_equal_data_air - .generate_trace_row((prev_op.data.clone(), curr_op.data.clone())) - .flatten()[2 * self.data_len..] + .generate_trace_row((prev_op.data.to_vec(), curr_op.data.to_vec())) + .flatten()[2 * WORD_SIZE..] .to_vec(); let lt_aux: Vec = lt_air .generate_trace_row(( @@ -175,17 +175,19 @@ impl OfflineChecker { row.extend(is_equal_data_aux); row.extend(lt_aux); + let mem_width = self.air.mem_width(); + if is_first_row { // same_addr_space should be 0 - row[2 + self.mem_width()] = F::zero(); + row[2 + mem_width] = F::zero(); // same_pointer should be 0 - row[3 + self.mem_width()] = F::zero(); + row[3 + mem_width] = F::zero(); // same_addr should be 0 - row[4 + self.mem_width()] = F::zero(); + row[4 + mem_width] = F::zero(); // same_data should be 0 - row[5 + self.mem_width()] = F::zero(); + row[5 + mem_width] = F::zero(); // lt_bit should be 1 - row[6 + self.mem_width()] = F::one(); + row[6 + mem_width] = F::one(); } row diff --git a/vm/src/memory/tests.rs b/vm/src/memory/tests.rs index e9ac1055cf..7f0b0f4af1 100644 --- a/vm/src/memory/tests.rs +++ b/vm/src/memory/tests.rs @@ -12,9 +12,9 @@ use p3_matrix::dense::RowMajorMatrix; use crate::cpu::{MEMORY_BUS, RANGE_CHECKER_BUS}; -use super::{offline_checker::OfflineChecker, MemoryAccess, OpType}; +use super::{offline_checker::MemoryChip, MemoryAccess, OpType}; -const DATA_LEN: usize = 3; +const WORD_SIZE: usize = 3; const ADDR_SPACE_LIMB_BITS: usize = 8; const POINTER_LIMB_BITS: usize = 8; const CLK_LIMB_BITS: usize = 8; @@ -26,14 +26,14 @@ const TRACE_DEGREE: usize = 16; #[test] fn test_offline_checker() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - DATA_LEN, + let memory_chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + memory_chip.air.mem_width(), true, MEMORY_BUS); + let ops: Vec> = vec![ MemoryAccess { @@ -152,34 +152,22 @@ fn test_offline_checker() { #[test] fn test_offline_checker_negative_invalid_read() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - DATA_LEN, + let mut memory_chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + memory_chip.air.mem_width(), true, MEMORY_BUS); // should fail because we can't read before writing - let ops: Vec> = vec![MemoryAccess { - timestamp: 0, - op_type: OpType::Read, - address_space: BabyBear::zero(), - address: BabyBear::zero(), - data: vec![ - BabyBear::from_canonical_usize(0), - BabyBear::from_canonical_usize(0), - BabyBear::from_canonical_usize(0), - ], - }]; + memory_chip.write_word(0, BabyBear::one(), BabyBear::zero(), [BabyBear::zero(), BabyBear::zero(), BabyBear::zero()]); - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + let memory_trace = memory_chip.generate_trace(range_checker.clone(), TRACE_DEGREE); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( - ops.iter() - .flat_map(|op: &MemoryAccess| { + memory_chip.accesses.iter() + .flat_map(|op: &MemoryAccess| { iter::once(BabyBear::one()) .chain(iter::once(BabyBear::from_canonical_usize(op.timestamp))) .chain(iter::once(BabyBear::from_canonical_u8(op.op_type as u8))) @@ -191,7 +179,7 @@ fn test_offline_checker_negative_invalid_read() { iter::repeat_with(|| { iter::repeat(BabyBear::zero()).take(1 + requester.field_width()) }) - .take(TRACE_DEGREE - ops.len()) + .take(TRACE_DEGREE - memory_chip.accesses.len()) .flatten(), ) .collect(), @@ -203,8 +191,8 @@ fn test_offline_checker_negative_invalid_read() { }); assert_eq!( run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&memory_chip.air, &range_checker.air, &requester], + vec![memory_trace, range_checker_trace, requester_trace], ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" @@ -215,7 +203,7 @@ fn test_offline_checker_negative_invalid_read() { fn test_offline_checker_negative_data_mismatch() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); let offline_checker = OfflineChecker::new( - DATA_LEN, + WORD_SIZE, ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, diff --git a/vm/src/program/mod.rs b/vm/src/program/mod.rs index 0abb2d2a48..a79dae56cf 100644 --- a/vm/src/program/mod.rs +++ b/vm/src/program/mod.rs @@ -11,24 +11,28 @@ pub mod bridge; pub mod columns; pub mod trace; -pub struct ProgramAir { - pub program: Vec>, +pub struct ProgramAir { + pub program: Vec>, } -impl ProgramAir { +pub struct ProgramChip { + pub air: ProgramAir, + execution_frequencies: Vec, +} + +impl ProgramChip { pub fn new(mut program: Vec>) -> Self { - // in order to make program length a power of 2, - // add instructions that jump to themselves - // so that any program that tries to jump to instructions that shouldn't exist - // will enter an infinite loop (so their termination cannot be proven) while !program.len().is_power_of_two() { - // op_c, as_c never matter in JAL - // op_a doesn't matter here (random address to write garbage to) - // op_b is the offset, needs to be 0 so we jump to self - // as_b should be nonzero so we don't write to immediate (may be unsupported) program.push(Instruction::from_isize(FAIL, 0, 0, 0, 0, 0)); } + Self { + execution_frequencies: vec![0; program.len()], + air: ProgramAir { program }, + } + } - Self { program } + pub fn get_instruction(&mut self, pc: usize) -> Instruction { + self.execution_frequencies[pc] += 1; + self.air.program[pc] } } diff --git a/vm/src/program/tests/mod.rs b/vm/src/program/tests/mod.rs index e17fe15675..28fbe0c7fa 100644 --- a/vm/src/program/tests/mod.rs +++ b/vm/src/program/tests/mod.rs @@ -9,8 +9,7 @@ use crate::cpu::{CpuAir, CpuOptions, READ_INSTRUCTION_BUS}; use crate::cpu::{trace::Instruction, OpCode::*}; use crate::program::columns::ProgramPreprocessedCols; - -use super::ProgramAir; +use crate::program::ProgramChip; #[test] fn test_flatten_fromslice_roundtrip() { @@ -33,8 +32,8 @@ fn interaction_test(is_field_arithmetic_enabled: bool, program: Vec ProgramAir { - pub fn generate_trace( - &self, - execution: &ProgramExecution, - ) -> RowMajorMatrix { - let mut frequencies = execution.execution_frequencies.clone(); - while frequencies.len() != self.program.len() { - frequencies.push(F::zero()); - } - RowMajorMatrix::new_col(frequencies) +impl ProgramChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + RowMajorMatrix::new_col( + self.execution_frequencies + .iter() + .map(|x| F::from_canonical_usize(*x)) + .collect::>(), + ) } } diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index b0223c334a..479f3559bb 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use afs_chips::range_gate::RangeCheckerGateChip; use afs_stark_backend::rap::AnyRap; -use p3_field::{PrimeField32, PrimeField64}; -use p3_matrix::{dense::DenseMatrix, Matrix}; +use p3_field::PrimeField32; +use p3_matrix::dense::DenseMatrix; use p3_uni_stark::{StarkGenericConfig, Val}; use p3_util::log2_strict_usize; @@ -12,42 +12,26 @@ use crate::{ trace::{ExecutionError, Instruction}, CpuAir, RANGE_CHECKER_BUS, }, - field_arithmetic::FieldArithmeticAir, - memory::{offline_checker::OfflineChecker, MemoryAccess}, - program::ProgramAir, + memory::offline_checker::MemoryChip, + program::ProgramChip, }; use self::config::{VmConfig, VmParamsConfig}; pub mod config; -pub struct VirtualMachine -where - Val: PrimeField64, -{ +pub struct VirtualMachine { pub config: VmParamsConfig, pub cpu_air: CpuAir, - pub program_air: ProgramAir>, - pub memory_air: OfflineChecker, - pub field_arithmetic_air: FieldArithmeticAir, + pub program_chip: ProgramChip, + pub memory_chip: MemoryChip, + pub field_arithmetic_chip: FieldArithmeticChip, pub range_checker: Arc, - - pub cpu_trace: DenseMatrix>, - pub program_trace: DenseMatrix>, - pub memory_trace: DenseMatrix>, - pub field_arithmetic_trace: DenseMatrix>, - pub range_trace: DenseMatrix>, } -impl VirtualMachine -where - Val: PrimeField32, -{ - pub fn new( - config: VmConfig, - program: Vec>>, - ) -> Result { +impl VirtualMachine { + pub fn new(config: VmConfig, program: Vec>) -> Result { let config = config.vm; let decomp = config.decomp; let limb_bits = config.limb_bits; @@ -55,84 +39,21 @@ where let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, 1 << decomp)); let cpu_air = CpuAir::new(config.cpu_options()); - let program_air = ProgramAir::new(program.clone()); - let memory_air = OfflineChecker::new(WORD_SIZE, limb_bits, limb_bits, limb_bits, decomp); - let field_arithmetic_air = FieldArithmeticAir::new(); - - let execution = cpu_air.generate_program_execution(program_air.program.clone())?; - let program_trace = program_air.generate_trace(&execution); - - let ops = execution - .memory_accesses - .iter() - .map(|access| MemoryAccess { - address: access.address, - op_type: access.op_type, - address_space: access.address_space, - timestamp: access.timestamp, - data: access.data.to_vec(), - }) - .collect::>(); - let memory_trace_degree = execution.memory_accesses.len().next_power_of_two(); - let memory_trace = - memory_air.generate_trace(ops, range_checker.clone(), memory_trace_degree); - - let range_trace: DenseMatrix> = range_checker.generate_trace(); - - let field_arithmetic_trace = field_arithmetic_air.generate_trace(&execution); + let program_chip = ProgramChip::new(program.clone()); + let memory_chip = MemoryChip::new(limb_bits, limb_bits, limb_bits, decomp); + let field_arithmetic_chip = FieldArithmeticChip::new(); Ok(Self { config, cpu_air, - program_air, - memory_air, - field_arithmetic_air, + program_chip, + memory_chip, + field_arithmetic_chip, range_checker, - cpu_trace: execution.trace(config.cpu_options()), - program_trace, - memory_trace, - field_arithmetic_trace, - range_trace, }) } - pub fn chips(&self) -> Vec<&dyn AnyRap> { - if self.config.field_arithmetic_enabled { - vec![ - &self.cpu_air, - &self.program_air, - &self.memory_air, - &self.field_arithmetic_air, - &self.range_checker.air, - ] - } else { - vec![ - &self.cpu_air, - &self.program_air, - &self.memory_air, - &self.range_checker.air, - ] - } - } - - pub fn traces(&self) -> Vec>> { - if self.config.field_arithmetic_enabled { - vec![ - self.cpu_trace.clone(), - self.program_trace.clone(), - self.memory_trace.clone(), - self.field_arithmetic_trace.clone(), - self.range_trace.clone(), - ] - } else { - vec![ - self.cpu_trace.clone(), - self.program_trace.clone(), - self.memory_trace.clone(), - self.range_trace.clone(), - ] - } - } + pub fn generate_traces(&self) -> Vec>> {} /*fn max_trace_heights(&self) -> Vec { let max_operations = self.config.max_operations; @@ -158,3 +79,27 @@ where log2_strict_usize(checker_trace_degree) } } + +impl VirtualMachine> +where + Val: PrimeField32, +{ + pub fn chips(&self) -> Vec<&dyn AnyRap> { + if self.config.field_arithmetic_enabled { + vec![ + &self.cpu_air, + &self.program_chip.air, + &self.memory_chip.air, + &self.field_arithmetic_chip.air, + &self.range_checker.air, + ] + } else { + vec![ + &self.cpu_air, + &self.program_chip.air, + &self.memory_chip.air, + &self.range_checker.air, + ] + } + } +} From 994eeb305efbcbb17a20efa77d10b7293fa5f818 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 11:47:02 -0400 Subject: [PATCH 2/9] Progress (currently debugging CPU tests) --- vm/bin/src/commands/keygen.rs | 10 +- vm/bin/src/commands/prove.rs | 15 +- vm/bin/src/commands/verify.rs | 13 +- vm/src/cpu/columns.rs | 2 +- vm/src/cpu/tests/mod.rs | 267 ++++++++++++----------- vm/src/cpu/trace.rs | 279 +++++++------------------ vm/src/field_arithmetic/mod.rs | 88 ++++++-- vm/src/field_arithmetic/tests.rs | 69 ++---- vm/src/field_arithmetic/trace.rs | 23 +- vm/src/memory/mod.rs | 2 +- vm/src/memory/offline_checker/mod.rs | 22 +- vm/src/memory/offline_checker/trace.rs | 6 +- vm/src/memory/tests.rs | 85 ++++---- vm/src/program/mod.rs | 2 +- vm/src/program/tests/mod.rs | 41 ++-- vm/src/vm/mod.rs | 74 ++++--- vm/tests/integration_test.rs | 10 +- 17 files changed, 483 insertions(+), 525 deletions(-) diff --git a/vm/bin/src/commands/keygen.rs b/vm/bin/src/commands/keygen.rs index 01d2ead7b5..4a883f556f 100644 --- a/vm/bin/src/commands/keygen.rs +++ b/vm/bin/src/commands/keygen.rs @@ -12,7 +12,7 @@ use clap::Parser; use color_eyre::eyre::Result; use itertools::Itertools; use p3_matrix::Matrix; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::asm::parse_asm_file; @@ -52,12 +52,12 @@ impl KeygenCommand { fn execute_helper(self, config: VmConfig) -> Result<()> { let instructions = parse_asm_file(Path::new(&self.asm_file_path.clone()))?; - let vm = VirtualMachine::::new(config, instructions)?; - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let mut vm = VirtualMachine::::new(config, instructions); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); let mut keygen_builder = engine.keygen_builder(); - let chips = vm.chips(); - let traces = vm.traces(); + let traces = vm.traces()?; + let chips = get_chips(&vm); for (chip, trace) in chips.into_iter().zip_eq(traces) { keygen_builder.add_air(chip, trace.height(), 0); diff --git a/vm/bin/src/commands/prove.rs b/vm/bin/src/commands/prove.rs index 1b1ba619ee..842cefc8e7 100644 --- a/vm/bin/src/commands/prove.rs +++ b/vm/bin/src/commands/prove.rs @@ -9,7 +9,7 @@ use afs_test_utils::{ }; use clap::Parser; use color_eyre::eyre::Result; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::{ asm::parse_asm_file, @@ -54,9 +54,9 @@ impl ProveCommand { pub fn execute_helper(&self, config: VmConfig) -> Result<()> { println!("Proving program: {}", self.asm_file_path); let instructions = parse_asm_file(Path::new(&self.asm_file_path.clone()))?; - let vm = VirtualMachine::::new(config, instructions)?; + let mut vm = VirtualMachine::::new(config, instructions); - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); let encoded_pk = read_from_path(&Path::new(&self.keys_folder.clone()).join("partial.pk"))?; let partial_pk: MultiStarkPartialProvingKey = bincode::deserialize(&encoded_pk)?; @@ -66,19 +66,22 @@ impl ProveCommand { let prover = engine.prover(); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - for trace in vm.traces() { + for trace in vm.traces()? { trace_builder.load_trace(trace); } trace_builder.commit_current(); - let main_trace_data = trace_builder.view(&partial_vk, vm.chips()); + let chips = get_chips(&vm); + let num_chips = chips.len(); + + let main_trace_data = trace_builder.view(&partial_vk, chips); let mut challenger = engine.new_challenger(); let proof = prover.prove( &mut challenger, &partial_pk, main_trace_data, - &vec![vec![]; vm.chips().len()], + &vec![vec![]; num_chips], ); let encoded_proof: Vec = bincode::serialize(&proof).unwrap(); diff --git a/vm/bin/src/commands/verify.rs b/vm/bin/src/commands/verify.rs index 235b47c799..d71ab01cf9 100644 --- a/vm/bin/src/commands/verify.rs +++ b/vm/bin/src/commands/verify.rs @@ -7,7 +7,7 @@ use afs_test_utils::{ }; use clap::Parser; use color_eyre::eyre::Result; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::{ asm::parse_asm_file, @@ -61,7 +61,7 @@ impl VerifyCommand { pub fn execute_helper(&self, config: VmConfig) -> Result<()> { println!("Verifying proof file: {}", self.proof_file); let instructions = parse_asm_file(Path::new(&self.asm_file_path))?; - let vm = VirtualMachine::::new(config, instructions)?; + let mut vm = VirtualMachine::::new(config, instructions); let encoded_vk = read_from_path(&Path::new(&self.keys_folder).join("partial.vk"))?; let partial_vk: MultiStarkPartialVerifyingKey = bincode::deserialize(&encoded_vk)?; @@ -69,16 +69,19 @@ impl VerifyCommand { let encoded_proof = read_from_path(Path::new(&self.proof_file))?; let proof: Proof = bincode::deserialize(&encoded_proof)?; - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); + + let chips = get_chips(&vm); + let num_chips = chips.len(); let mut challenger = engine.new_challenger(); let verifier = engine.verifier(); let result = verifier.verify( &mut challenger, partial_vk, - vm.chips(), + chips, proof, - &vec![vec![]; vm.chips().len()], + &vec![vec![]; num_chips], ); if result.is_err() { diff --git a/vm/src/cpu/columns.rs b/vm/src/cpu/columns.rs index fd9bc54589..2e31ed5e69 100644 --- a/vm/src/cpu/columns.rs +++ b/vm/src/cpu/columns.rs @@ -50,7 +50,7 @@ impl CpuIoCols { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct MemoryAccessCols { pub enabled: T, diff --git a/vm/src/cpu/tests/mod.rs b/vm/src/cpu/tests/mod.rs index b830a87a94..ec8e9d62b6 100644 --- a/vm/src/cpu/tests/mod.rs +++ b/vm/src/cpu/tests/mod.rs @@ -3,22 +3,41 @@ use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; use afs_test_utils::interaction::dummy_interaction_air::DummyInteractionAir; use p3_baby_bear::BabyBear; -use p3_field::{AbstractField, PrimeField64}; -use p3_matrix::dense::RowMajorMatrix; +use p3_field::{AbstractField, PrimeField32, PrimeField64}; +use p3_matrix::dense::{DenseMatrix, RowMajorMatrix}; +use p3_matrix::Matrix; use crate::cpu::columns::{CpuCols, CpuIoCols}; use crate::cpu::{CpuAir, CpuOptions}; -use crate::memory::OpType; +use crate::field_arithmetic::ArithmeticOperation; +use crate::memory::{MemoryAccess, OpType}; +use crate::vm::config::{VmConfig, VmParamsConfig}; +use crate::vm::VirtualMachine; use super::columns::MemoryAccessCols; -use super::trace::{isize_to_field, ProgramExecution}; +use super::trace::isize_to_field; use super::{decompose, ARITHMETIC_BUS, MEMORY_BUS, READ_INSTRUCTION_BUS}; -use super::{ - trace::{ArithmeticOperation, Instruction, MemoryAccess}, - OpCode::*, -}; +use super::{trace::Instruction, OpCode::*}; const TEST_WORD_SIZE: usize = 4; +const LIMB_BITS: usize = 8; +const DECOMP: usize = 4; + +fn make_vm( + program: Vec>, + field_arithmetic_enabled: bool, +) -> VirtualMachine { + VirtualMachine::::new( + VmConfig { + vm: VmParamsConfig { + field_arithmetic_enabled, + limb_bits: LIMB_BITS, + decomp: DECOMP, + }, + }, + program, + ) +} impl MemoryAccess { pub fn from_isize( @@ -56,29 +75,56 @@ fn test_flatten_fromslice_roundtrip() { assert_eq!(num_cols, flattened.len()); } -fn program_execution_test( +/*fn test( + field_arithmetic_enabled: bool, + program: Vec>, + mut expected_execution: Vec, + expected_memory_log: Vec>, + expected_arithmetic_operations: Vec>, +) { + program_execution_test( + field_arithmetic_enabled, + program, + expected_execution, + expected_memory_log, + expected_arithmetic_operations, + ); + let mut expected_execution_frequencies = expected_execution.clone(); + for i in 0..expected_execution.len() { + expected_execution_frequencies[i] += 1; + } + air_test( + field_arithmetic_enabled, + program, + expected_execution_frequencies, + expected_memory_log, + expected_arithmetic_operations, + ); +}*/ + +fn execution_test( field_arithmetic_enabled: bool, program: Vec>, mut expected_execution: Vec, expected_memory_log: Vec>, expected_arithmetic_operations: Vec>, ) { - let air = CpuAir::new(CpuOptions { - field_arithmetic_enabled, - }); - let execution = air.generate_program_execution(program.clone()).unwrap(); + let mut vm = make_vm(program.clone(), field_arithmetic_enabled); + let mut trace = CpuAir::generate_trace(&mut vm).unwrap(); - assert_eq!(execution.program, program); - assert_eq!(execution.memory_accesses, expected_memory_log); - assert_eq!(execution.arithmetic_ops, expected_arithmetic_operations); + assert_eq!(vm.memory_chip.accesses, expected_memory_log); + assert_eq!( + vm.field_arithmetic_chip.operations, + expected_arithmetic_operations + ); while !expected_execution.len().is_power_of_two() { expected_execution.push(*expected_execution.last().unwrap()); } - assert_eq!(execution.trace_rows.len(), expected_execution.len()); - for (i, row) in execution.trace_rows.iter().enumerate() { - let pc = expected_execution[i]; + assert_eq!(trace.height(), expected_execution.len()); + for (i, &pc) in expected_execution.iter().enumerate() { + let cols = CpuCols::::from_slice(trace.row_mut(i), vm.config.cpu_options()); let expected_io = CpuIoCols { clock_cycle: F::from_canonical_u64(i as u64), pc: F::from_canonical_u64(pc as u64), @@ -89,16 +135,15 @@ fn program_execution_test( d: program[pc].d, e: program[pc].e, }; - assert_eq!(row.io, expected_io); + assert_eq!(cols.io, expected_io); } - let mut execution_frequency_check = execution.execution_frequencies.clone(); - for row in execution.trace_rows { - let pc = row.io.pc.as_canonical_u64() as usize; - execution_frequency_check[pc] += F::neg_one(); + let mut execution_frequency_check = vm.program_chip.execution_frequencies.clone(); + for pc in expected_execution { + execution_frequency_check[pc] -= 1; } for frequency in execution_frequency_check.iter() { - assert_eq!(*frequency, F::zero()); + assert_eq!(*frequency, 0); } } @@ -106,61 +151,59 @@ fn air_test( field_arithmetic_enabled: bool, program: Vec>, ) { - let air = CpuAir::new(CpuOptions { - field_arithmetic_enabled, - }); - let execution = air.generate_program_execution(program).unwrap(); - air_test_custom_execution::(field_arithmetic_enabled, execution); + air_test_change::(field_arithmetic_enabled, program, false, |_, _| {}); } fn air_test_change_pc( field_arithmetic_enabled: bool, program: Vec>, - change_row: usize, - change_value: usize, should_fail: bool, + change_row: usize, + new: usize, ) { - let air = CpuAir::new(CpuOptions { + air_test_change::( field_arithmetic_enabled, - }); - let mut execution = air.generate_program_execution(program).unwrap(); - - let old_value = execution.trace_rows[change_row].io.pc.as_canonical_u64() as usize; - execution.trace_rows[change_row].io.pc = BabyBear::from_canonical_usize(change_value); - - execution.execution_frequencies[old_value] -= BabyBear::one(); - execution.execution_frequencies[change_value] += BabyBear::one(); - - air_test_custom_execution_with_failure::( - field_arithmetic_enabled, - execution, + program, should_fail, + |rows, vm| { + let old = rows[change_row].io.pc.as_canonical_u64() as usize; + rows[change_row].io.pc = BabyBear::from_canonical_usize(new); + vm.program_chip.execution_frequencies[new] += 1; + vm.program_chip.execution_frequencies[old] -= 1; + }, ); } -fn air_test_custom_execution( +fn air_test_change< + const WORD_SIZE: usize, + F: Fn(&mut Vec>, &mut VirtualMachine), +>( field_arithmetic_enabled: bool, - execution: ProgramExecution, -) { - air_test_custom_execution_with_failure(field_arithmetic_enabled, execution, false); -} - -fn air_test_custom_execution_with_failure( - field_arithmetic_enabled: bool, - execution: ProgramExecution, + program: Vec>, should_fail: bool, + change: F, ) { - let options = CpuOptions { - field_arithmetic_enabled, - }; - let air = CpuAir::::new(options); - let trace = execution.trace(options); + let mut vm = make_vm(program.clone(), field_arithmetic_enabled); + let mut trace = CpuAir::generate_trace(&mut vm).unwrap(); + let mut rows = vec![]; + for i in 0..trace.height() { + rows.push(CpuCols::::from_slice( + trace.row_mut(i), + vm.config.cpu_options(), + )); + } + change(&mut rows, &mut vm); + let mut flattened = vec![]; + for row in rows { + flattened.extend(row.flatten(vm.config.cpu_options())); + } + let trace = DenseMatrix::new(flattened, trace.width()); let program_air = DummyInteractionAir::new(7, false, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(vm.program_chip.execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, @@ -177,7 +220,7 @@ fn air_test_custom_execution_with_failure( let memory_air = DummyInteractionAir::new(5, false, MEMORY_BUS); let mut memory_rows = vec![]; - for memory_access in execution.memory_accesses.iter() { + for memory_access in vm.memory_chip.accesses.iter() { memory_rows.extend(vec![ BabyBear::one(), BabyBear::from_canonical_usize(memory_access.timestamp), @@ -194,7 +237,7 @@ fn air_test_custom_execution_with_failure( let arithmetic_air = DummyInteractionAir::new(4, false, ARITHMETIC_BUS); let mut arithmetic_rows = vec![]; - for arithmetic_op in execution.arithmetic_ops.iter() { + for arithmetic_op in vm.field_arithmetic_chip.operations.iter() { arithmetic_rows.extend(vec![ BabyBear::one(), BabyBear::from_canonical_usize(arithmetic_op.opcode as usize), @@ -210,12 +253,12 @@ fn air_test_custom_execution_with_failure( let test_result = if field_arithmetic_enabled { run_simple_test_no_pis( - vec![&air, &program_air, &memory_air, &arithmetic_air], + vec![&vm.cpu_air, &program_air, &memory_air, &arithmetic_air], vec![trace, program_trace, memory_trace, arithmetic_trace], ) } else { run_simple_test_no_pis( - vec![&air, &program_air, &memory_air], + vec![&vm.cpu_air, &program_air, &memory_air], vec![trace, program_trace, memory_trace], ) }; @@ -288,7 +331,7 @@ fn test_cpu_1() { )); } - program_execution_test::( + execution_test::( true, program.clone(), expected_execution, @@ -330,7 +373,7 @@ fn test_cpu_without_field_arithmetic() { MemoryAccess::from_isize(6, OpType::Read, 1, 0, 5), ]; - program_execution_test::( + execution_test::( field_arithmetic_enabled, program.clone(), expected_execution, @@ -363,7 +406,7 @@ fn test_cpu_negative_wrong_pc() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change_pc::(true, program, 2, 3, true); + air_test_change_pc::(true, program, true, 2, 3); } #[test] @@ -382,7 +425,7 @@ fn test_cpu_negative_wrong_pc_check() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change_pc::(true, program, 2, 2, false); + air_test_change_pc::(true, program, false, 2, 2); } #[test] @@ -396,15 +439,11 @@ fn test_cpu_negative_hasnt_terminated() { // terminate Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - execution.trace_rows.remove(execution.trace_rows.len() - 1); - execution.execution_frequencies[1] = AbstractField::zero(); - air_test_custom_execution::(true, execution); + air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { + rows.remove(rows.len() - 1); + vm.program_chip.execution_frequencies[1] = 0; + }); } #[test] @@ -417,32 +456,26 @@ fn test_cpu_negative_secret_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - let is_zero_air = IsZeroAir; - let mut is_zero_trace = is_zero_air - .generate_trace(vec![AbstractField::one()]) - .clone(); - let is_zero_aux = is_zero_trace.row_mut(0)[2]; - - execution.trace_rows[0].aux.accesses[2] = MemoryAccessCols { - enabled: AbstractField::one(), - address_space: AbstractField::one(), - is_immediate: AbstractField::zero(), - is_zero_aux, - address: AbstractField::zero(), - data: decompose(AbstractField::from_canonical_usize(115)), - }; - - execution - .memory_accesses - .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); + air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { + let is_zero_air = IsZeroAir; + let mut is_zero_trace = is_zero_air + .generate_trace(vec![AbstractField::one()]) + .clone(); + let is_zero_aux = is_zero_trace.row_mut(0)[2]; + + rows[0].aux.accesses[2] = MemoryAccessCols { + enabled: AbstractField::one(), + address_space: AbstractField::one(), + is_immediate: AbstractField::zero(), + is_zero_aux, + address: AbstractField::zero(), + data: decompose(AbstractField::from_canonical_usize(115)), + }; - air_test_custom_execution::(true, execution); + vm.memory_chip + .accesses + .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); + }); } #[test] @@ -455,17 +488,10 @@ fn test_cpu_negative_disable_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - execution.trace_rows[0].aux.accesses[2].enabled = AbstractField::zero(); - - execution.memory_accesses.remove(0); - - air_test_custom_execution::(true, execution); + air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { + rows[0].aux.accesses[2].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(0); + }); } #[test] @@ -478,15 +504,8 @@ fn test_cpu_negative_disable_read() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - execution.trace_rows[0].aux.accesses[0].enabled = AbstractField::zero(); - - execution.memory_accesses.remove(0); - - air_test_custom_execution::(true, execution); + air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { + rows[0].aux.accesses[0].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(0); + }); } diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index 521e8b748e..36e9fa4811 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -1,24 +1,19 @@ -use std::{ - array::from_fn, - collections::{BTreeMap, HashMap}, - error::Error, - fmt::Display, -}; +use std::{collections::BTreeMap, error::Error, fmt::Display}; -use p3_field::{Field, PrimeField64}; +use p3_field::{Field, PrimeField32, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; use afs_chips::{ is_equal_vec::IsEqualVecAir, is_zero::IsZeroAir, sub_chip::LocalTraceInstructions, }; -use crate::{field_arithmetic::FieldArithmeticAir, memory::OpType, vm::VirtualMachine}; +use crate::{cpu::MAX_ACCESSES_PER_CYCLE, vm::VirtualMachine}; use super::{ columns::{CpuAuxCols, CpuCols, CpuIoCols, MemoryAccessCols}, - compose, decompose, CpuAir, CpuOptions, + compose, decompose, CpuAir, OpCode::{self, *}, - INST_WIDTH, MAX_READS_PER_CYCLE, MAX_WRITES_PER_CYCLE, + INST_WIDTH, MAX_READS_PER_CYCLE, }; #[derive(Copy, Clone, Debug, PartialEq, Eq, derive_new::new)] @@ -31,34 +26,14 @@ pub struct Instruction { pub e: F, } -impl ArithmeticOperation { - pub fn from_isize(opcode: OpCode, operand1: isize, operand2: isize, result: isize) -> Self { - Self { - opcode, - operand1: isize_to_field::(operand1), - operand2: isize_to_field::(operand2), - result: isize_to_field::(result), - } - } - - pub fn to_vec(&self) -> Vec { - vec![ - F::from_canonical_usize(self.opcode as usize), - self.operand1, - self.operand2, - self.result, - ] - } -} - -pub fn isize_to_field(value: isize) -> F { +pub fn isize_to_field(value: isize) -> F { if value < 0 { return F::neg_one() * F::from_canonical_usize(value.unsigned_abs()); } F::from_canonical_usize(value as usize) } -impl Instruction { +impl Instruction { pub fn from_isize( opcode: OpCode, op_a: isize, @@ -78,48 +53,30 @@ impl Instruction { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct MemoryAccess { - pub timestamp: usize, - pub op_type: OpType, - pub address_space: F, - pub address: F, - pub data: [F; WORD_SIZE], +fn disabled_memory_cols() -> MemoryAccessCols +{ + memory_access_to_cols(false, F::one(), F::zero(), [F::zero(); WORD_SIZE]) } fn memory_access_to_cols( - access: Option<&MemoryAccess>, + enabled: bool, + address_space: F, + address: F, + data: [F; WORD_SIZE], ) -> MemoryAccessCols { - let (enabled, address_space, address, value) = match access { - Some(&MemoryAccess { - address_space, - address, - data, - .. - }) => (F::one(), address_space, address, data), - None => (F::zero(), F::one(), F::zero(), [F::zero(); WORD_SIZE]), - }; let is_zero_cols = LocalTraceInstructions::generate_trace_row(&IsZeroAir {}, address_space); let is_immediate = is_zero_cols.io.is_zero; let is_zero_aux = is_zero_cols.inv; MemoryAccessCols { - enabled, + enabled: F::from_bool(enabled), address_space, is_immediate, is_zero_aux, address, - data: value, + data, } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct ArithmeticOperation { - pub opcode: OpCode, - pub operand1: F, - pub operand2: F, - pub result: F, -} - #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct FieldExtensionOperation { pub opcode: OpCode, @@ -138,88 +95,6 @@ impl FieldExtensionOperation { } } -struct Memory { - data: HashMap>, - log: Vec>, - clock_cycle: usize, - reads_this_cycle: Vec>, - writes_this_cycle: Vec>, -} - -impl Memory { - fn new() -> Self { - let mut data = HashMap::new(); - data.insert(F::one(), HashMap::new()); - data.insert(F::two(), HashMap::new()); - - Self { - data, - log: vec![], - clock_cycle: 0, - reads_this_cycle: vec![], - writes_this_cycle: vec![], - } - } - - fn read(&mut self, address_space: F, address: F) -> [F; WORD_SIZE] { - let data = if address_space == F::zero() { - decompose::(address) - } else { - *self.data[&address_space] - .get(&address) - .unwrap_or(&[F::zero(); WORD_SIZE]) - }; - let read = MemoryAccess { - timestamp: ((MAX_READS_PER_CYCLE + MAX_WRITES_PER_CYCLE) * self.clock_cycle) - + self.reads_this_cycle.len(), - op_type: OpType::Read, - address_space, - address, - data, - }; - if read.address_space != F::zero() { - self.log.push(read); - } - self.reads_this_cycle.push(read); - data - } - - fn write(&mut self, address_space: F, address: F, data: [F; WORD_SIZE]) { - if address_space == F::zero() { - panic!("Attempted to write to address space 0"); - } else { - let write = MemoryAccess { - timestamp: ((MAX_READS_PER_CYCLE + MAX_WRITES_PER_CYCLE) * self.clock_cycle) - + MAX_READS_PER_CYCLE - + self.writes_this_cycle.len(), - op_type: OpType::Write, - address_space, - address, - data, - }; - self.log.push(write); - self.writes_this_cycle.push(write); - - self.data - .get_mut(&address_space) - .unwrap() - .insert(address, data); - } - } - - fn complete_clock_cycle( - &mut self, - ) -> ( - Vec>, - Vec>, - ) { - self.clock_cycle += 1; - let reads = std::mem::take(&mut self.reads_this_cycle); - let writes = std::mem::take(&mut self.writes_this_cycle); - (reads, writes) - } -} - #[derive(Debug)] pub enum ExecutionError { Fail(usize), @@ -244,24 +119,18 @@ impl Display for ExecutionError { impl Error for ExecutionError {} impl CpuAir { - pub fn generate_program_execution( - &self, - vm: VirtualMachine, - ) -> RowMajorMatrix { + pub fn generate_trace( + vm: &mut VirtualMachine, + ) -> Result, ExecutionError> { let mut rows = vec![]; - let mut execution_frequencies = vec![F::zero(); program.len()]; - let mut arithmetic_operations = vec![]; let mut clock_cycle: usize = 0; let mut pc = F::zero(); - let mut memory = Memory::new(); - loop { let pc_usize = pc.as_canonical_u64() as usize; - execution_frequencies[pc_usize] += F::one(); - let instruction = program[pc_usize]; + let instruction = vm.program_chip.get_instruction(pc_usize); let opcode = instruction.opcode; let a = instruction.op_a; let b = instruction.op_b; @@ -282,36 +151,70 @@ impl CpuAir { let mut next_pc = pc + F::one(); + let mut accesses = [disabled_memory_cols(); MAX_ACCESSES_PER_CYCLE]; + let mut read_index = 0; + let mut write_index = MAX_READS_PER_CYCLE; + + macro_rules! read { + ($address_space: expr, $address: expr) => {{ + assert!(read_index < MAX_READS_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + read_index; + let data = if $address_space == F::zero() { + decompose::($address) + } else { + vm.memory_chip + .read_word(timestamp, $address_space, $address) + }; + accesses[read_index] = + memory_access_to_cols(true, $address_space, $address, data); + read_index += 1; + compose(data) + }}; + } + + macro_rules! write { + ($address_space: expr, $address: expr, $data: expr) => {{ + assert!(write_index < MAX_ACCESSES_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + write_index; + let word = decompose($data); + vm.memory_chip + .write_word(timestamp, $address_space, $address, word); + accesses[write_index] = + memory_access_to_cols(true, $address_space, $address, word); + write_index += 1; + }}; + } + match opcode { // d[a] <- e[d[c] + b] LOADW => { - let base_pointer = compose(memory.read(d, c)); - let value = memory.read(e, base_pointer + b); - memory.write(d, a, value); + let base_pointer = read!(d, c); + let value = read!(e, base_pointer + b); + write!(d, a, value); } // e[d[c] + b] <- d[a] STOREW => { - let base_pointer = compose(memory.read(d, c)); - let value = memory.read(d, a); - memory.write(e, base_pointer + b, value); + let base_pointer = read!(d, c); + let value = read!(d, a); + write!(e, base_pointer + b, value); } // d[a] <- pc + INST_WIDTH, pc <- pc + b JAL => { - memory.write(d, a, decompose(pc + F::from_canonical_usize(INST_WIDTH))); + write!(d, a, pc + F::from_canonical_usize(INST_WIDTH)); next_pc = pc + b; } // If d[a] = e[b], pc <- pc + c BEQ => { - let left = memory.read(d, a); - let right = memory.read(e, b); + let left = read!(d, a); + let right = read!(e, b); if left == right { next_pc = pc + c; } } // If d[a] != e[b], pc <- pc + c BNE => { - let left = memory.read(d, a); - let right = memory.read(e, b); + let left = read!(d, a); + let right = read!(e, b); if left != right { next_pc = pc + c; } @@ -320,50 +223,31 @@ impl CpuAir { next_pc = pc; } opcode @ (FADD | FSUB | FMUL | FDIV) => { - if self.options.field_arithmetic_enabled { + if vm.config.cpu_options().field_arithmetic_enabled { // read from d[b] and e[c] - let operand1 = compose(memory.read(d, b)); - let operand2 = compose(memory.read(e, c)); + let operand1 = read!(d, b); + let operand2 = read!(e, c); // write to d[a] - let result = - FieldArithmeticAir::solve(opcode, (operand1, operand2)).unwrap(); - memory.write(d, a, decompose(result)); - - arithmetic_operations.push(ArithmeticOperation { - opcode, - operand1, - operand2, - result, - }); + let result = vm + .field_arithmetic_chip + .calculate(opcode, (operand1, operand2)); + write!(d, a, result); } else { return Err(ExecutionError::DisabledOperation(opcode)); } } FAIL => return Err(ExecutionError::Fail(pc_usize)), PRINTF => { - let value = memory.read(d, a); - println!("{}", compose(value)); + let value = read!(d, a); + println!("{}", value); } }; let mut operation_flags = BTreeMap::new(); - for other_opcode in self.options.enabled_instructions() { + for other_opcode in vm.config.cpu_options().enabled_instructions() { operation_flags.insert(other_opcode, F::from_bool(other_opcode == opcode)); } - // complete the clock cycle and get the read and write cols - let (reads, writes) = memory.complete_clock_cycle(); - assert!(reads.len() <= MAX_READS_PER_CYCLE); - assert!(writes.len() <= MAX_WRITES_PER_CYCLE); - - let accesses = from_fn(|i| { - memory_access_to_cols(if i < MAX_READS_PER_CYCLE { - reads.get(i) - } else { - writes.get(i - MAX_READS_PER_CYCLE) - }) - }); - let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row( &IsEqualVecAir::new(WORD_SIZE), (accesses[0].data.to_vec(), accesses[1].data.to_vec()), @@ -380,7 +264,7 @@ impl CpuAir { }; let cols = CpuCols { io, aux }; - rows.push(cols); + rows.extend(cols.flatten(vm.config.cpu_options())); pc = next_pc; clock_cycle += 1; @@ -390,12 +274,9 @@ impl CpuAir { } } - Ok(ProgramExecution { - program, - execution_frequencies, - trace_rows: rows, - memory_accesses: memory.log, - arithmetic_ops: arithmetic_operations, - }) + Ok(RowMajorMatrix::new( + rows, + CpuCols::::get_width(vm.config.cpu_options()), + )) } } diff --git a/vm/src/field_arithmetic/mod.rs b/vm/src/field_arithmetic/mod.rs index 5c961cb3b8..092dac33be 100644 --- a/vm/src/field_arithmetic/mod.rs +++ b/vm/src/field_arithmetic/mod.rs @@ -1,5 +1,5 @@ -use super::cpu::trace::ArithmeticOperation; -use crate::cpu::OpCode; +use crate::cpu::{trace::isize_to_field, OpCode}; +use itertools::Itertools; use p3_field::Field; #[cfg(test)] @@ -13,6 +13,35 @@ pub mod trace; /// Field arithmetic chip. /// /// Carries information about opcodes (currently 6..=9) and bus index (currently 2). + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ArithmeticOperation { + pub opcode: OpCode, + pub operand1: F, + pub operand2: F, + pub result: F, +} + +impl ArithmeticOperation { + pub fn from_isize(opcode: OpCode, operand1: isize, operand2: isize, result: isize) -> Self { + Self { + opcode, + operand1: isize_to_field::(operand1), + operand2: isize_to_field::(operand2), + result: isize_to_field::(result), + } + } + + pub fn to_vec(&self) -> Vec { + vec![ + F::from_canonical_usize(self.opcode as usize), + self.operand1, + self.operand2, + self.result, + ] + } +} + #[derive(Default, Clone, Copy)] pub struct FieldArithmeticAir {} @@ -20,14 +49,10 @@ impl FieldArithmeticAir { pub const BASE_OP: u8 = OpCode::FADD as u8; pub const BUS_INDEX: usize = 2; - pub fn new() -> Self { - Self {} - } - /// Evaluates given opcode using given operands. /// /// Returns None for non-arithmetic operations. - pub fn solve(op: OpCode, operands: (T, T)) -> Option { + fn solve(op: OpCode, operands: (T, T)) -> Option { match op { OpCode::FADD => Some(operands.0 + operands.1), OpCode::FSUB => Some(operands.0 - operands.1), @@ -44,20 +69,41 @@ impl FieldArithmeticAir { .filter_map(|(op, operand)| Self::solve::(*op, *operand)) .collect() } +} - /// Converts vectorized opcodes and operands into vectorized ArithmeticOperations. - pub fn request( - ops: Vec, - operands: Vec<(T, T)>, - ) -> Vec> { - ops.iter() - .zip(operands.iter()) - .map(|(op, operand)| ArithmeticOperation { - opcode: *op, - operand1: operand.0, - operand2: operand.1, - result: Self::solve::(*op, *operand).unwrap(), - }) - .collect() +pub struct FieldArithmeticChip { + pub air: FieldArithmeticAir, + pub operations: Vec>, +} + +impl FieldArithmeticChip { + pub fn new() -> Self { + Self { + air: FieldArithmeticAir {}, + operations: vec![], + } + } + + pub fn calculate(&mut self, op: OpCode, operands: (F, F)) -> F { + let result = FieldArithmeticAir::solve::(op, operands).unwrap(); + self.operations.push(ArithmeticOperation { + opcode: op, + operand1: operands.0, + operand2: operands.1, + result, + }); + result + } + + pub fn request(&mut self, ops: Vec, operands_vec: Vec<(F, F)>) { + for (op, operands) in ops.iter().zip_eq(operands_vec.iter()) { + self.calculate(*op, *operands); + } + } +} + +impl Default for FieldArithmeticChip { + fn default() -> Self { + Self::new() } } diff --git a/vm/src/field_arithmetic/tests.rs b/vm/src/field_arithmetic/tests.rs index dfdfeb7a62..ddcc538081 100644 --- a/vm/src/field_arithmetic/tests.rs +++ b/vm/src/field_arithmetic/tests.rs @@ -1,7 +1,7 @@ use super::columns::FieldArithmeticCols; use super::columns::FieldArithmeticIOCols; use super::FieldArithmeticAir; -use crate::cpu::trace::{ArithmeticOperation, ProgramExecution}; +use super::FieldArithmeticChip; use crate::cpu::OpCode; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; @@ -14,7 +14,7 @@ use p3_matrix::dense::RowMajorMatrix; use rand::Rng; /// Function for testing that generates a random program consisting only of field arithmetic operations. -fn generate_arith_program(len_ops: usize) -> ProgramExecution<1, BabyBear> { +fn generate_arith_program(chip: &mut FieldArithmeticChip, len_ops: usize) { let mut rng = create_seeded_rng(); let ops = (0..len_ops) .map(|_| OpCode::from_u8(rng.gen_range(6..=9)).unwrap()) @@ -27,15 +27,7 @@ fn generate_arith_program(len_ops: usize) -> ProgramExecution<1, BabyBear> { ) }) .collect(); - let arith_ops = FieldArithmeticAir::request(ops, operands); - - ProgramExecution { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: arith_ops, - } + chip.request(ops, operands); } #[test] @@ -43,12 +35,12 @@ fn au_air_test() { let mut rng = create_seeded_rng(); let len_ops: usize = 3; let correct_height = len_ops.next_power_of_two(); - let prog = generate_arith_program(len_ops); - let au_air = FieldArithmeticAir::new(); + let mut chip = FieldArithmeticChip::new(); + generate_arith_program(&mut chip, len_ops); let empty_dummy_row = FieldArithmeticCols::::blank_row().io.flatten(); let dummy_trace = RowMajorMatrix::new( - prog.arithmetic_ops + chip.operations .clone() .iter() .flat_map(|op| { @@ -62,7 +54,7 @@ fn au_air_test() { FieldArithmeticIOCols::::get_width(), ); - let mut au_trace = au_air.generate_trace(&prog); + let mut au_trace = chip.generate_trace(); let page_requester = DummyInteractionAir::new( FieldArithmeticIOCols::::get_width() - 1, @@ -72,13 +64,13 @@ fn au_air_test() { // positive test run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ) .expect("Verification failed"); // negative test pranking each IO value - for height in 0..(prog.arithmetic_ops.len()) { + for height in 0..(chip.operations.len()) { for width in 0..FieldArithmeticIOCols::::get_width() { let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); au_trace.row_mut(height)[width] = prank_value; @@ -90,28 +82,21 @@ fn au_air_test() { }); assert_eq!( run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ), Err(VerificationError::OodEvaluationMismatch), "Expected constraint to fail" ) } +} - let zero_div_zero_prog = ProgramExecution::<1, BabyBear> { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: vec![ArithmeticOperation { - opcode: OpCode::FDIV, - operand1: BabyBear::zero(), - operand2: BabyBear::one(), - result: BabyBear::zero(), - }], - }; +#[test] +fn au_air_zero_div_zero() { + let mut chip = FieldArithmeticChip::new(); + chip.calculate(OpCode::FDIV, (BabyBear::zero(), BabyBear::one())); - let mut au_trace = au_air.generate_trace(&zero_div_zero_prog); + let mut au_trace = chip.generate_trace(); au_trace.row_mut(0)[3] = BabyBear::zero(); let page_requester = DummyInteractionAir::new( FieldArithmeticIOCols::::get_width() - 1, @@ -133,7 +118,7 @@ fn au_air_test() { }); assert_eq!( run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ), Err(VerificationError::OodEvaluationMismatch), @@ -144,21 +129,7 @@ fn au_air_test() { #[should_panic] #[test] fn au_air_test_panic() { - let au_air = FieldArithmeticAir::new(); - - let zero_div_zero_prog = ProgramExecution::<1, BabyBear> { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: vec![ArithmeticOperation { - opcode: OpCode::FDIV, - operand1: BabyBear::zero(), - operand2: BabyBear::zero(), - result: BabyBear::zero(), - }], - }; - - // Should panic - au_air.generate_trace(&zero_div_zero_prog); + let mut chip = FieldArithmeticChip::new(); + // should panic + chip.calculate(OpCode::FDIV, (BabyBear::zero(), BabyBear::zero())); } diff --git a/vm/src/field_arithmetic/trace.rs b/vm/src/field_arithmetic/trace.rs index 88dfc69f30..b7e3f02870 100644 --- a/vm/src/field_arithmetic/trace.rs +++ b/vm/src/field_arithmetic/trace.rs @@ -1,8 +1,8 @@ use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -use super::columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}; -use crate::cpu::{trace::ProgramExecution, OpCode}; +use super::{columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}, FieldArithmeticChip}; +use crate::cpu::OpCode; use super::FieldArithmeticAir; @@ -51,14 +51,11 @@ fn generate_cols(op: OpCode, x: T, y: T) -> FieldArithmeticCols { } } -impl FieldArithmeticAir { +impl FieldArithmeticChip { /// Generates trace for field arithmetic chip. - pub fn generate_trace( - &self, - prog_exec: &ProgramExecution, - ) -> RowMajorMatrix { - let mut trace: Vec = prog_exec - .arithmetic_ops + pub fn generate_trace(&self) -> RowMajorMatrix { + let mut trace: Vec = self + .operations .iter() .flat_map(|op| { let cols = generate_cols(op.opcode, op.operand1, op.operand2); @@ -66,17 +63,17 @@ impl FieldArithmeticAir { }) .collect(); - let empty_row = FieldArithmeticCols::::blank_row().flatten(); - let curr_height = prog_exec.arithmetic_ops.len(); + let empty_row: Vec = FieldArithmeticCols::blank_row().flatten(); + let curr_height = self.operations.len(); let correct_height = curr_height.next_power_of_two(); trace.extend( empty_row .iter() .cloned() .cycle() - .take((correct_height - curr_height) * FieldArithmeticCols::::NUM_COLS), + .take((correct_height - curr_height) * FieldArithmeticCols::::NUM_COLS), ); - RowMajorMatrix::new(trace, FieldArithmeticCols::::NUM_COLS) + RowMajorMatrix::new(trace, FieldArithmeticCols::::NUM_COLS) } } diff --git a/vm/src/memory/mod.rs b/vm/src/memory/mod.rs index 0bb36ba6c6..679bd72234 100644 --- a/vm/src/memory/mod.rs +++ b/vm/src/memory/mod.rs @@ -8,7 +8,7 @@ pub enum OpType { Write = 1, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct MemoryAccess { pub timestamp: usize, pub op_type: OpType, diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 72079bf54a..8500613e4b 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{array::from_fn, collections::HashMap}; use afs_chips::is_less_than_tuple::columns::IsLessThanTupleAuxCols; use p3_field::PrimeField32; @@ -37,7 +37,7 @@ impl OfflineChecker { pub struct MemoryChip { pub air: OfflineChecker, pub accesses: Vec>, - memory: HashMap<(F, F), [F; WORD_SIZE]>, + memory: HashMap<(F, F), F>, last_timestamp: Option, } @@ -60,19 +60,21 @@ impl MemoryChip { } pub fn read_word(&mut self, timestamp: usize, address_space: F, address: F) -> [F; WORD_SIZE] { + // temporary, as cpu trace generation currently works using word-addressing + let address = F::from_canonical_usize(WORD_SIZE) * address; + assert!(address_space != F::zero()); - assert!(timestamp % WORD_SIZE == 0); if let Some(last_timestamp) = self.last_timestamp { assert!(timestamp > last_timestamp); } self.last_timestamp = Some(timestamp); - let data = self.memory[&(address_space, address)]; + let data = from_fn(|i| self.memory[&(address_space, address + F::from_canonical_usize(i))]); self.accesses.push(MemoryAccess { timestamp, op_type: OpType::Read, address_space, address, - data: data.clone(), + data, }); data } @@ -84,19 +86,23 @@ impl MemoryChip { address: F, data: [F; WORD_SIZE], ) { + // temporary, as cpu trace generation currently works using word-addressing + let address = F::from_canonical_usize(WORD_SIZE) * address; + assert!(address_space != F::zero()); - assert!(timestamp % WORD_SIZE == 0); if let Some(last_timestamp) = self.last_timestamp { assert!(timestamp > last_timestamp); } self.last_timestamp = Some(timestamp); - self.memory.insert((address_space, address), data); + for (i, &datum) in data.iter().enumerate() { + self.memory.insert((address_space, address + F::from_canonical_usize(i)), datum); + } self.accesses.push(MemoryAccess { timestamp, op_type: OpType::Write, address_space, address, - data: data.clone(), + data, }); } } diff --git a/vm/src/memory/offline_checker/trace.rs b/vm/src/memory/offline_checker/trace.rs index a8a912248d..0b674d1a48 100644 --- a/vm/src/memory/offline_checker/trace.rs +++ b/vm/src/memory/offline_checker/trace.rs @@ -24,7 +24,6 @@ impl MemoryChip { pub fn generate_trace( &mut self, range_checker: Arc, - trace_degree: usize, ) -> RowMajorMatrix { self.accesses .sort_by_key(|op| (op.address_space, op.address, op.timestamp)); @@ -60,7 +59,8 @@ impl MemoryChip { } // Ensure that trace degree is a power of two - assert!(trace_degree > 0 && trace_degree & (trace_degree - 1) == 0); + let trace_degree = self.accesses.len().next_power_of_two(); + if self.accesses.len() < trace_degree { rows.extend(self.generate_trace_row( @@ -99,7 +99,7 @@ impl MemoryChip { row.push(F::from_canonical_usize(curr_op.timestamp)); row.push(curr_op.address_space); row.push(curr_op.address); - row.extend(curr_op.data.clone()); + row.extend(curr_op.data); row.push(F::from_canonical_u8(op_type)); let same_addr_space = if curr_op.address_space == prev_op.address_space { diff --git a/vm/src/memory/tests.rs b/vm/src/memory/tests.rs index 7f0b0f4af1..028f171d4d 100644 --- a/vm/src/memory/tests.rs +++ b/vm/src/memory/tests.rs @@ -26,22 +26,21 @@ const TRACE_DEGREE: usize = 16; #[test] fn test_offline_checker() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let memory_chip = MemoryChip::new( + let mut chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + memory_chip.air.mem_width(), true, MEMORY_BUS); - + let requester = DummyInteractionAir::new(2 + chip.air.mem_width(), true, MEMORY_BUS); - let ops: Vec> = vec![ + let ops: Vec> = vec![ MemoryAccess { timestamp: 1, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -50,9 +49,9 @@ fn test_offline_checker() { MemoryAccess { timestamp: 0, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(2324), BabyBear::from_canonical_usize(433), BabyBear::from_canonical_usize(1778), @@ -63,7 +62,7 @@ fn test_offline_checker() { op_type: OpType::Write, address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(231), BabyBear::from_canonical_usize(3883), BabyBear::from_canonical_usize(17), @@ -72,9 +71,9 @@ fn test_offline_checker() { MemoryAccess { timestamp: 2, op_type: OpType::Read, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -85,7 +84,7 @@ fn test_offline_checker() { op_type: OpType::Read, address_space: BabyBear::two(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(4382), BabyBear::from_canonical_usize(8837), BabyBear::from_canonical_usize(192), @@ -96,7 +95,7 @@ fn test_offline_checker() { op_type: OpType::Write, address_space: BabyBear::two(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(4382), BabyBear::from_canonical_usize(8837), BabyBear::from_canonical_usize(192), @@ -105,22 +104,34 @@ fn test_offline_checker() { MemoryAccess { timestamp: 3, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(3243), BabyBear::from_canonical_usize(3214), BabyBear::from_canonical_usize(6639), ], }, ]; + let mut ops_sorted = ops.clone(); + ops_sorted.sort_by_key(|op| op.timestamp); + + for op in ops_sorted.iter() { + match op.op_type { + OpType::Read => { + assert_eq!(chip.read_word(op.timestamp, op.address_space, op.address), op.data); + } + OpType::Write => { + chip.write_word(op.timestamp, op.address_space, op.address, op.data); + } + } + } - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + let trace = chip.generate_trace(range_checker.clone()); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( ops.iter() - .flat_map(|op: &MemoryAccess| { + .flat_map(|op: &MemoryAccess| { [ BabyBear::one(), BabyBear::from_canonical_usize(op.timestamp), @@ -143,8 +154,8 @@ fn test_offline_checker() { ); run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&chip.air, &range_checker.air, &requester], + vec![trace, range_checker_trace, requester_trace], ) .expect("Verification failed"); } @@ -162,8 +173,9 @@ fn test_offline_checker_negative_invalid_read() { // should fail because we can't read before writing memory_chip.write_word(0, BabyBear::one(), BabyBear::zero(), [BabyBear::zero(), BabyBear::zero(), BabyBear::zero()]); + memory_chip.accesses[0].op_type = OpType::Read; - let memory_trace = memory_chip.generate_trace(range_checker.clone(), TRACE_DEGREE); + let memory_trace = memory_chip.generate_trace(range_checker.clone()); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( memory_chip.accesses.iter() @@ -202,22 +214,21 @@ fn test_offline_checker_negative_invalid_read() { #[test] fn test_offline_checker_negative_data_mismatch() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - WORD_SIZE, + let mut chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + chip.air.mem_width(), true, MEMORY_BUS); - let ops: Vec> = vec![ + let ops: Vec> = vec![ MemoryAccess { timestamp: 0, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(2324), BabyBear::from_canonical_usize(433), BabyBear::from_canonical_usize(1778), @@ -226,9 +237,9 @@ fn test_offline_checker_negative_data_mismatch() { MemoryAccess { timestamp: 1, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -238,9 +249,9 @@ fn test_offline_checker_negative_data_mismatch() { MemoryAccess { timestamp: 2, op_type: OpType::Read, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(233), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -248,12 +259,14 @@ fn test_offline_checker_negative_data_mismatch() { }, ]; - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + chip.accesses.clone_from(&ops); + + let trace = chip.generate_trace(range_checker.clone()); + let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( ops.iter() - .flat_map(|op: &MemoryAccess| { + .flat_map(|op: &MemoryAccess| { iter::once(BabyBear::one()) .chain(iter::once(BabyBear::from_canonical_usize(op.timestamp))) .chain(iter::once(BabyBear::from_canonical_u8(op.op_type as u8))) @@ -277,8 +290,8 @@ fn test_offline_checker_negative_data_mismatch() { }); assert_eq!( run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&chip.air, &range_checker.air, &requester], + vec![trace, range_checker_trace, requester_trace], ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" diff --git a/vm/src/program/mod.rs b/vm/src/program/mod.rs index a79dae56cf..00531822e0 100644 --- a/vm/src/program/mod.rs +++ b/vm/src/program/mod.rs @@ -17,7 +17,7 @@ pub struct ProgramAir { pub struct ProgramChip { pub air: ProgramAir, - execution_frequencies: Vec, + pub execution_frequencies: Vec, } impl ProgramChip { diff --git a/vm/src/program/tests/mod.rs b/vm/src/program/tests/mod.rs index 28fbe0c7fa..b77426adfd 100644 --- a/vm/src/program/tests/mod.rs +++ b/vm/src/program/tests/mod.rs @@ -5,7 +5,7 @@ use p3_field::AbstractField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use crate::cpu::{CpuAir, CpuOptions, READ_INSTRUCTION_BUS}; +use crate::cpu::READ_INSTRUCTION_BUS; use crate::cpu::{trace::Instruction, OpCode::*}; use crate::program::columns::ProgramPreprocessedCols; @@ -26,20 +26,20 @@ fn test_flatten_fromslice_roundtrip() { assert_eq!(num_cols, flattened.len()); } -fn interaction_test(is_field_arithmetic_enabled: bool, program: Vec>) { - let cpu_air = CpuAir::<1>::new(CpuOptions { - field_arithmetic_enabled: is_field_arithmetic_enabled, - }); - let execution = cpu_air.generate_program_execution(program.clone()).unwrap(); - - let chip = ProgramChip::new(program); +fn interaction_test(program: Vec>, execution: Vec) { + let mut chip = ProgramChip::new(program.clone()); + let mut execution_frequencies = vec![0; program.len()]; + for pc in execution { + execution_frequencies[pc] += 1; + chip.get_instruction(pc); + } let trace = chip.generate_trace(); let counter_air = DummyInteractionAir::new(7, true, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, @@ -80,13 +80,11 @@ fn test_program_1() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - interaction_test(true, program.clone()); + interaction_test(program.clone(), vec![0, 3, 2, 5]); } #[test] fn test_program_without_field_arithmetic() { - let field_arithmetic_enabled = false; - // see cpu/tests/mod.rs let program = vec![ // word[0]_1 <- word[5]_0 @@ -101,7 +99,7 @@ fn test_program_without_field_arithmetic() { Instruction::from_isize(BEQ, 0, 5, -1, 1, 0), ]; - interaction_test(field_arithmetic_enabled, program.clone()); + interaction_test(program.clone(), vec![0, 2, 4, 1]); } #[test] @@ -113,19 +111,18 @@ fn test_program_negative() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let cpu_air = CpuAir::<1>::new(CpuOptions { - field_arithmetic_enabled: true, - }); - let execution = cpu_air.generate_program_execution(program.clone()).unwrap(); - - let chip = ProgramChip::new(program); + let mut chip = ProgramChip::new(program.clone()); + let execution_frequencies = vec![1; program.len()]; + for pc in 0..program.len() { + chip.get_instruction(pc); + } let trace = chip.generate_trace(); let counter_air = DummyInteractionAir::new(7, true, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 479f3559bb..820b74836b 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -3,15 +3,18 @@ use std::sync::Arc; use afs_chips::range_gate::RangeCheckerGateChip; use afs_stark_backend::rap::AnyRap; use p3_field::PrimeField32; -use p3_matrix::dense::DenseMatrix; +use p3_matrix::{dense::DenseMatrix, Matrix}; use p3_uni_stark::{StarkGenericConfig, Val}; use p3_util::log2_strict_usize; +pub enum Void {} + use crate::{ cpu::{ trace::{ExecutionError, Instruction}, CpuAir, RANGE_CHECKER_BUS, }, + field_arithmetic::FieldArithmeticChip, memory::offline_checker::MemoryChip, program::ProgramChip, }; @@ -28,10 +31,12 @@ pub struct VirtualMachine { pub memory_chip: MemoryChip, pub field_arithmetic_chip: FieldArithmeticChip, pub range_checker: Arc, + + traces: Vec>, } impl VirtualMachine { - pub fn new(config: VmConfig, program: Vec>) -> Result { + pub fn new(config: VmConfig, program: Vec>) -> Self { let config = config.vm; let decomp = config.decomp; let limb_bits = config.limb_bits; @@ -43,17 +48,34 @@ impl VirtualMachine { let memory_chip = MemoryChip::new(limb_bits, limb_bits, limb_bits, decomp); let field_arithmetic_chip = FieldArithmeticChip::new(); - Ok(Self { + Self { config, cpu_air, program_chip, memory_chip, field_arithmetic_chip, range_checker, - }) + traces: vec![], + } + } + + fn generate_traces(&mut self) -> Result>, ExecutionError> { + let cpu_trace = CpuAir::generate_trace(self)?; + Ok(vec![ + cpu_trace, + self.program_chip.generate_trace(), + self.memory_chip.generate_trace(self.range_checker.clone()), + self.field_arithmetic_chip.generate_trace(), + self.range_checker.generate_trace(), + ]) } - pub fn generate_traces(&self) -> Vec>> {} + pub fn traces(&mut self) -> Result>, ExecutionError> { + if self.traces.is_empty() { + self.traces = self.generate_traces()?; + } + Ok(self.traces.clone()) + } /*fn max_trace_heights(&self) -> Vec { let max_operations = self.config.max_operations; @@ -71,35 +93,35 @@ impl VirtualMachine { .collect() }*/ - pub fn max_log_degree(&self) -> usize { + pub fn max_log_degree(&mut self) -> Result { let mut checker_trace_degree = 0; - for trace in self.traces() { + for trace in self.traces()? { checker_trace_degree = std::cmp::max(checker_trace_degree, trace.height()); } - log2_strict_usize(checker_trace_degree) + Ok(log2_strict_usize(checker_trace_degree)) } } -impl VirtualMachine> +pub fn get_chips( + vm: &VirtualMachine>, +) -> Vec<&dyn AnyRap> where Val: PrimeField32, { - pub fn chips(&self) -> Vec<&dyn AnyRap> { - if self.config.field_arithmetic_enabled { - vec![ - &self.cpu_air, - &self.program_chip.air, - &self.memory_chip.air, - &self.field_arithmetic_chip.air, - &self.range_checker.air, - ] - } else { - vec![ - &self.cpu_air, - &self.program_chip.air, - &self.memory_chip.air, - &self.range_checker.air, - ] - } + if vm.config.field_arithmetic_enabled { + vec![ + &vm.cpu_air, + &vm.program_chip.air, + &vm.memory_chip.air, + &vm.field_arithmetic_chip.air, + &vm.range_checker.air, + ] + } else { + vec![ + &vm.cpu_air, + &vm.program_chip.air, + &vm.memory_chip.air, + &vm.range_checker.air, + ] } } diff --git a/vm/tests/integration_test.rs b/vm/tests/integration_test.rs index bccb7afb39..643bd4581f 100644 --- a/vm/tests/integration_test.rs +++ b/vm/tests/integration_test.rs @@ -5,6 +5,7 @@ use stark_vm::cpu::trace::Instruction; use stark_vm::cpu::OpCode::*; use stark_vm::vm::config::VmConfig; use stark_vm::vm::config::VmParamsConfig; +use stark_vm::vm::get_chips; use stark_vm::vm::VirtualMachine; const WORD_SIZE: usize = 1; @@ -12,7 +13,7 @@ const LIMB_BITS: usize = 8; const DECOMP: usize = 4; fn air_test(field_arithmetic_enabled: bool, program: Vec>) { - let vm = VirtualMachine::::new( + let mut vm = VirtualMachine::::new( VmConfig { vm: VmParamsConfig { field_arithmetic_enabled, @@ -21,10 +22,9 @@ fn air_test(field_arithmetic_enabled: bool, program: Vec>) }, }, program, - ) - .unwrap(); - let chips = vm.chips(); - let traces = vm.traces(); + ); + let traces = vm.traces().unwrap(); + let chips = get_chips(&vm); run_simple_test_no_pis(chips, traces).expect("Verification failed"); } From 90d2ed176c46e7d61bae0fcd56d263296c4f46e2 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 12:26:04 -0400 Subject: [PATCH 3/9] All tests passing --- vm/src/cpu/tests/mod.rs | 42 ++++++++++++++++++++++------ vm/src/cpu/trace.rs | 10 +++---- vm/src/memory/offline_checker/mod.rs | 8 ++---- vm/src/vm/mod.rs | 20 ++++++++----- vm/tests/integration_test.rs | 2 ++ 5 files changed, 56 insertions(+), 26 deletions(-) diff --git a/vm/src/cpu/tests/mod.rs b/vm/src/cpu/tests/mod.rs index ec8e9d62b6..e0ecccb5f0 100644 --- a/vm/src/cpu/tests/mod.rs +++ b/vm/src/cpu/tests/mod.rs @@ -19,7 +19,7 @@ use super::trace::isize_to_field; use super::{decompose, ARITHMETIC_BUS, MEMORY_BUS, READ_INSTRUCTION_BUS}; use super::{trace::Instruction, OpCode::*}; -const TEST_WORD_SIZE: usize = 4; +const TEST_WORD_SIZE: usize = 1; const LIMB_BITS: usize = 8; const DECOMP: usize = 4; @@ -112,7 +112,13 @@ fn execution_test( let mut vm = make_vm(program.clone(), field_arithmetic_enabled); let mut trace = CpuAir::generate_trace(&mut vm).unwrap(); - assert_eq!(vm.memory_chip.accesses, expected_memory_log); + let mut actual_memory_log = vm.memory_chip.accesses.clone(); + // temporary + for access in actual_memory_log.iter_mut() { + access.address = access.address / F::from_canonical_usize(WORD_SIZE); + } + + assert_eq!(actual_memory_log, expected_memory_log); assert_eq!( vm.field_arithmetic_chip.operations, expected_arithmetic_operations @@ -124,7 +130,7 @@ fn execution_test( assert_eq!(trace.height(), expected_execution.len()); for (i, &pc) in expected_execution.iter().enumerate() { - let cols = CpuCols::::from_slice(trace.row_mut(i), vm.config.cpu_options()); + let cols = CpuCols::::from_slice(trace.row_mut(i), vm.options()); let expected_io = CpuIoCols { clock_cycle: F::from_canonical_u64(i as u64), pc: F::from_canonical_u64(pc as u64), @@ -189,13 +195,13 @@ fn air_test_change< for i in 0..trace.height() { rows.push(CpuCols::::from_slice( trace.row_mut(i), - vm.config.cpu_options(), + vm.options(), )); } change(&mut rows, &mut vm); let mut flattened = vec![]; for row in rows { - flattened.extend(row.flatten(vm.config.cpu_options())); + flattened.extend(row.flatten(vm.options())); } let trace = DenseMatrix::new(flattened, trace.width()); @@ -496,8 +502,10 @@ fn test_cpu_negative_disable_write() { #[test] #[should_panic(expected = "assertion `left == right` failed")] -fn test_cpu_negative_disable_read() { +fn test_cpu_negative_disable_read0() { let program = vec![ + // word[0]_1 <- 0 + Instruction::from_isize(STOREW, 0, 0, 0, 0, 1), // if word[0]_0 == word[0]_[0] then pc += 1 Instruction::from_isize(LOADW, 0, 0, 0, 1, 1), // terminate @@ -505,7 +513,25 @@ fn test_cpu_negative_disable_read() { ]; air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - rows[0].aux.accesses[0].enabled = AbstractField::zero(); - vm.memory_chip.accesses.remove(0); + rows[1].aux.accesses[0].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(1); }); } + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn test_cpu_negative_disable_read1() { + let program = vec![ + // word[0]_1 <- 0 + Instruction::from_isize(STOREW, 0, 0, 0, 0, 1), + // if word[0]_0 == word[0]_[0] then pc += 1 + Instruction::from_isize(LOADW, 0, 0, 0, 1, 1), + // terminate + Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), + ]; + + air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { + rows[1].aux.accesses[1].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(2); + }); +} \ No newline at end of file diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index 36e9fa4811..f81b6adb7f 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -223,7 +223,7 @@ impl CpuAir { next_pc = pc; } opcode @ (FADD | FSUB | FMUL | FDIV) => { - if vm.config.cpu_options().field_arithmetic_enabled { + if vm.options().field_arithmetic_enabled { // read from d[b] and e[c] let operand1 = read!(d, b); let operand2 = read!(e, c); @@ -244,7 +244,7 @@ impl CpuAir { }; let mut operation_flags = BTreeMap::new(); - for other_opcode in vm.config.cpu_options().enabled_instructions() { + for other_opcode in vm.options().enabled_instructions() { operation_flags.insert(other_opcode, F::from_bool(other_opcode == opcode)); } @@ -264,19 +264,19 @@ impl CpuAir { }; let cols = CpuCols { io, aux }; - rows.extend(cols.flatten(vm.config.cpu_options())); + rows.extend(cols.flatten(vm.options())); pc = next_pc; clock_cycle += 1; - if opcode == TERMINATE && rows.len().is_power_of_two() { + if opcode == TERMINATE && clock_cycle.is_power_of_two() { break; } } Ok(RowMajorMatrix::new( rows, - CpuCols::::get_width(vm.config.cpu_options()), + CpuCols::::get_width(vm.options()), )) } } diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 8500613e4b..078348e62e 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -60,9 +60,6 @@ impl MemoryChip { } pub fn read_word(&mut self, timestamp: usize, address_space: F, address: F) -> [F; WORD_SIZE] { - // temporary, as cpu trace generation currently works using word-addressing - let address = F::from_canonical_usize(WORD_SIZE) * address; - assert!(address_space != F::zero()); if let Some(last_timestamp) = self.last_timestamp { assert!(timestamp > last_timestamp); @@ -76,6 +73,7 @@ impl MemoryChip { address, data, }); + println!("access: {:?}", self.accesses.last().unwrap()); data } @@ -86,9 +84,6 @@ impl MemoryChip { address: F, data: [F; WORD_SIZE], ) { - // temporary, as cpu trace generation currently works using word-addressing - let address = F::from_canonical_usize(WORD_SIZE) * address; - assert!(address_space != F::zero()); if let Some(last_timestamp) = self.last_timestamp { assert!(timestamp > last_timestamp); @@ -104,5 +99,6 @@ impl MemoryChip { address, data, }); + println!("access: {:?}", self.accesses.last().unwrap()); } } diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 820b74836b..a741d3904b 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -11,8 +11,7 @@ pub enum Void {} use crate::{ cpu::{ - trace::{ExecutionError, Instruction}, - CpuAir, RANGE_CHECKER_BUS, + trace::{ExecutionError, Instruction}, CpuAir, CpuOptions, RANGE_CHECKER_BUS }, field_arithmetic::FieldArithmeticChip, memory::offline_checker::MemoryChip, @@ -59,15 +58,22 @@ impl VirtualMachine { } } + pub fn options(&self) -> CpuOptions { + self.config.cpu_options() + } + fn generate_traces(&mut self) -> Result>, ExecutionError> { let cpu_trace = CpuAir::generate_trace(self)?; - Ok(vec![ + let mut result = vec![ cpu_trace, self.program_chip.generate_trace(), self.memory_chip.generate_trace(self.range_checker.clone()), - self.field_arithmetic_chip.generate_trace(), self.range_checker.generate_trace(), - ]) + ]; + if self.options().field_arithmetic_enabled { + result.push(self.field_arithmetic_chip.generate_trace()); + } + Ok(result) } pub fn traces(&mut self) -> Result>, ExecutionError> { @@ -108,13 +114,13 @@ pub fn get_chips( where Val: PrimeField32, { - if vm.config.field_arithmetic_enabled { + if vm.options().field_arithmetic_enabled { vec![ &vm.cpu_air, &vm.program_chip.air, &vm.memory_chip.air, - &vm.field_arithmetic_chip.air, &vm.range_checker.air, + &vm.field_arithmetic_chip.air, ] } else { vec![ diff --git a/vm/tests/integration_test.rs b/vm/tests/integration_test.rs index 643bd4581f..16b9294275 100644 --- a/vm/tests/integration_test.rs +++ b/vm/tests/integration_test.rs @@ -57,6 +57,8 @@ fn test_vm_1() { #[test] fn test_vm_without_field_arithmetic() { + std::env::set_var("RUST_BACKTRACE", "1"); + let field_arithmetic_enabled = false; /* From 0510b9761212330c11d9ddbc589e4a1a83c1aa05 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:01:53 -0400 Subject: [PATCH 4/9] Remove warning for read/write macro in cpu trace --- vm/src/cpu/trace.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index f81b6adb7f..1f1dd4ec30 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -7,13 +7,13 @@ use afs_chips::{ is_equal_vec::IsEqualVecAir, is_zero::IsZeroAir, sub_chip::LocalTraceInstructions, }; -use crate::{cpu::MAX_ACCESSES_PER_CYCLE, vm::VirtualMachine}; +use crate::vm::VirtualMachine; use super::{ columns::{CpuAuxCols, CpuCols, CpuIoCols, MemoryAccessCols}, compose, decompose, CpuAir, OpCode::{self, *}, - INST_WIDTH, MAX_READS_PER_CYCLE, + INST_WIDTH, MAX_ACCESSES_PER_CYCLE, MAX_READS_PER_CYCLE, MAX_WRITES_PER_CYCLE, }; #[derive(Copy, Clone, Debug, PartialEq, Eq, derive_new::new)] @@ -152,22 +152,22 @@ impl CpuAir { let mut next_pc = pc + F::one(); let mut accesses = [disabled_memory_cols(); MAX_ACCESSES_PER_CYCLE]; - let mut read_index = 0; - let mut write_index = MAX_READS_PER_CYCLE; + let mut num_reads = 0; + let mut num_writes = 0; macro_rules! read { ($address_space: expr, $address: expr) => {{ - assert!(read_index < MAX_READS_PER_CYCLE); - let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + read_index; + num_reads += 1; + assert!(num_reads <= MAX_READS_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + (num_reads - 1); let data = if $address_space == F::zero() { decompose::($address) } else { vm.memory_chip .read_word(timestamp, $address_space, $address) }; - accesses[read_index] = + accesses[num_reads - 1] = memory_access_to_cols(true, $address_space, $address, data); - read_index += 1; compose(data) }}; } @@ -179,9 +179,8 @@ impl CpuAir { let word = decompose($data); vm.memory_chip .write_word(timestamp, $address_space, $address, word); - accesses[write_index] = + accesses[MAX_READS_PER_CYCLE + num_writes - 1] = memory_access_to_cols(true, $address_space, $address, word); - write_index += 1; }}; } From 0d7d797ccc250a719955f3a00183e2acc0171ab5 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:02:28 -0400 Subject: [PATCH 5/9] Add in change that got left out --- vm/src/cpu/trace.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index 1f1dd4ec30..65117c8ce5 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -174,8 +174,10 @@ impl CpuAir { macro_rules! write { ($address_space: expr, $address: expr, $data: expr) => {{ - assert!(write_index < MAX_ACCESSES_PER_CYCLE); - let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + write_index; + num_writes += 1; + assert!(num_writes <= MAX_WRITES_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + + (MAX_READS_PER_CYCLE + num_writes - 1); let word = decompose($data); vm.memory_chip .write_word(timestamp, $address_space, $address, word); From 4d87dddbe29c6d1daea8621bda7f97232c8ea5c1 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:13:01 -0400 Subject: [PATCH 6/9] Lint field_arithmetic/trace.rs --- vm/src/field_arithmetic/trace.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vm/src/field_arithmetic/trace.rs b/vm/src/field_arithmetic/trace.rs index b7e3f02870..5567a912d9 100644 --- a/vm/src/field_arithmetic/trace.rs +++ b/vm/src/field_arithmetic/trace.rs @@ -1,7 +1,10 @@ use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -use super::{columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}, FieldArithmeticChip}; +use super::{ + columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}, + FieldArithmeticChip, +}; use crate::cpu::OpCode; use super::FieldArithmeticAir; From 91d4e82ad439268491a56087cbc8df4bd2b89572 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:16:03 -0400 Subject: [PATCH 7/9] Fix lint, remove rust backtrace line --- vm/src/cpu/tests/mod.rs | 99 ++++++++++++++++---------- vm/src/memory/offline_checker/mod.rs | 3 +- vm/src/memory/offline_checker/trace.rs | 1 - vm/src/memory/tests.rs | 16 ++++- vm/src/vm/mod.rs | 3 +- vm/tests/integration_test.rs | 2 - 6 files changed, 79 insertions(+), 45 deletions(-) diff --git a/vm/src/cpu/tests/mod.rs b/vm/src/cpu/tests/mod.rs index e0ecccb5f0..87db140c63 100644 --- a/vm/src/cpu/tests/mod.rs +++ b/vm/src/cpu/tests/mod.rs @@ -446,10 +446,15 @@ fn test_cpu_negative_hasnt_terminated() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - rows.remove(rows.len() - 1); - vm.program_chip.execution_frequencies[1] = 0; - }); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows.remove(rows.len() - 1); + vm.program_chip.execution_frequencies[1] = 0; + }, + ); } #[test] @@ -462,26 +467,31 @@ fn test_cpu_negative_secret_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - let is_zero_air = IsZeroAir; - let mut is_zero_trace = is_zero_air - .generate_trace(vec![AbstractField::one()]) - .clone(); - let is_zero_aux = is_zero_trace.row_mut(0)[2]; - - rows[0].aux.accesses[2] = MemoryAccessCols { - enabled: AbstractField::one(), - address_space: AbstractField::one(), - is_immediate: AbstractField::zero(), - is_zero_aux, - address: AbstractField::zero(), - data: decompose(AbstractField::from_canonical_usize(115)), - }; - - vm.memory_chip - .accesses - .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); - }); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + let is_zero_air = IsZeroAir; + let mut is_zero_trace = is_zero_air + .generate_trace(vec![AbstractField::one()]) + .clone(); + let is_zero_aux = is_zero_trace.row_mut(0)[2]; + + rows[0].aux.accesses[2] = MemoryAccessCols { + enabled: AbstractField::one(), + address_space: AbstractField::one(), + is_immediate: AbstractField::zero(), + is_zero_aux, + address: AbstractField::zero(), + data: decompose(AbstractField::from_canonical_usize(115)), + }; + + vm.memory_chip + .accesses + .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); + }, + ); } #[test] @@ -494,10 +504,15 @@ fn test_cpu_negative_disable_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - rows[0].aux.accesses[2].enabled = AbstractField::zero(); - vm.memory_chip.accesses.remove(0); - }); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[0].aux.accesses[2].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(0); + }, + ); } #[test] @@ -512,10 +527,15 @@ fn test_cpu_negative_disable_read0() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - rows[1].aux.accesses[0].enabled = AbstractField::zero(); - vm.memory_chip.accesses.remove(1); - }); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[1].aux.accesses[0].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(1); + }, + ); } #[test] @@ -530,8 +550,13 @@ fn test_cpu_negative_disable_read1() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change(true, program, true, |rows, vm: &mut VirtualMachine| { - rows[1].aux.accesses[1].enabled = AbstractField::zero(); - vm.memory_chip.accesses.remove(2); - }); -} \ No newline at end of file + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[1].aux.accesses[1].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(2); + }, + ); +} diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 078348e62e..9da116cbf7 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -90,7 +90,8 @@ impl MemoryChip { } self.last_timestamp = Some(timestamp); for (i, &datum) in data.iter().enumerate() { - self.memory.insert((address_space, address + F::from_canonical_usize(i)), datum); + self.memory + .insert((address_space, address + F::from_canonical_usize(i)), datum); } self.accesses.push(MemoryAccess { timestamp, diff --git a/vm/src/memory/offline_checker/trace.rs b/vm/src/memory/offline_checker/trace.rs index 0b674d1a48..0074ebb741 100644 --- a/vm/src/memory/offline_checker/trace.rs +++ b/vm/src/memory/offline_checker/trace.rs @@ -61,7 +61,6 @@ impl MemoryChip { // Ensure that trace degree is a power of two let trace_degree = self.accesses.len().next_power_of_two(); - if self.accesses.len() < trace_degree { rows.extend(self.generate_trace_row( false, diff --git a/vm/src/memory/tests.rs b/vm/src/memory/tests.rs index 028f171d4d..533b7152a4 100644 --- a/vm/src/memory/tests.rs +++ b/vm/src/memory/tests.rs @@ -119,7 +119,10 @@ fn test_offline_checker() { for op in ops_sorted.iter() { match op.op_type { OpType::Read => { - assert_eq!(chip.read_word(op.timestamp, op.address_space, op.address), op.data); + assert_eq!( + chip.read_word(op.timestamp, op.address_space, op.address), + op.data + ); } OpType::Write => { chip.write_word(op.timestamp, op.address_space, op.address, op.data); @@ -172,13 +175,20 @@ fn test_offline_checker_negative_invalid_read() { let requester = DummyInteractionAir::new(2 + memory_chip.air.mem_width(), true, MEMORY_BUS); // should fail because we can't read before writing - memory_chip.write_word(0, BabyBear::one(), BabyBear::zero(), [BabyBear::zero(), BabyBear::zero(), BabyBear::zero()]); + memory_chip.write_word( + 0, + BabyBear::one(), + BabyBear::zero(), + [BabyBear::zero(), BabyBear::zero(), BabyBear::zero()], + ); memory_chip.accesses[0].op_type = OpType::Read; let memory_trace = memory_chip.generate_trace(range_checker.clone()); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( - memory_chip.accesses.iter() + memory_chip + .accesses + .iter() .flat_map(|op: &MemoryAccess| { iter::once(BabyBear::one()) .chain(iter::once(BabyBear::from_canonical_usize(op.timestamp))) diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index a741d3904b..73faa6b3b7 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -11,7 +11,8 @@ pub enum Void {} use crate::{ cpu::{ - trace::{ExecutionError, Instruction}, CpuAir, CpuOptions, RANGE_CHECKER_BUS + trace::{ExecutionError, Instruction}, + CpuAir, CpuOptions, RANGE_CHECKER_BUS, }, field_arithmetic::FieldArithmeticChip, memory::offline_checker::MemoryChip, diff --git a/vm/tests/integration_test.rs b/vm/tests/integration_test.rs index 16b9294275..643bd4581f 100644 --- a/vm/tests/integration_test.rs +++ b/vm/tests/integration_test.rs @@ -57,8 +57,6 @@ fn test_vm_1() { #[test] fn test_vm_without_field_arithmetic() { - std::env::set_var("RUST_BACKTRACE", "1"); - let field_arithmetic_enabled = false; /* From 8e52196a6dd57a5a5d860eebef8c1447ea4076b8 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:40:06 -0400 Subject: [PATCH 8/9] Fix compiler/src/util.rs, remove debugging code --- compiler/src/util.rs | 28 ++++++++++++++++++---------- vm/src/memory/offline_checker/mod.rs | 2 -- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/compiler/src/util.rs b/compiler/src/util.rs index 5cee462154..bc927976e5 100644 --- a/compiler/src/util.rs +++ b/compiler/src/util.rs @@ -1,7 +1,10 @@ use p3_field::PrimeField32; -use stark_vm::cpu::{ - trace::{Instruction, ProgramExecution}, - CpuAir, CpuOptions, +use stark_vm::{ + cpu::trace::Instruction, + vm::{ + config::{VmConfig, VmParamsConfig}, + VirtualMachine, + }, }; pub fn canonical_i32_to_field(x: i32) -> F { @@ -14,13 +17,18 @@ pub fn canonical_i32_to_field(x: i32) -> F { } } -pub fn execute_program( - program: Vec>, -) -> ProgramExecution { - let cpu = CpuAir::new(CpuOptions { - field_arithmetic_enabled: true, - }); - cpu.generate_program_execution(program).unwrap() +pub fn execute_program(program: Vec>) { + let mut vm = VirtualMachine::::new( + VmConfig { + vm: VmParamsConfig { + field_arithmetic_enabled: true, + limb_bits: 28, + decomp: 4, + }, + }, + program, + ); + vm.traces().unwrap(); } pub fn display_program(program: &[Instruction]) { diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 9da116cbf7..445a0110e0 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -73,7 +73,6 @@ impl MemoryChip { address, data, }); - println!("access: {:?}", self.accesses.last().unwrap()); data } @@ -100,6 +99,5 @@ impl MemoryChip { address, data, }); - println!("access: {:?}", self.accesses.last().unwrap()); } } From b48e4c8a54ad570eaf81f5e1b197774fa5c0f6f7 Mon Sep 17 00:00:00 2001 From: TlatoaniHJ Date: Tue, 2 Jul 2024 13:56:53 -0400 Subject: [PATCH 9/9] Comment out test_compiler_break --- compiler/tests/for_loops.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/tests/for_loops.rs b/compiler/tests/for_loops.rs index b5d5e947c1..c5139572db 100644 --- a/compiler/tests/for_loops.rs +++ b/compiler/tests/for_loops.rs @@ -175,8 +175,8 @@ fn test_compiler_break() { builder.halt(); - let program = builder.compile_isa(); - execute_program::(program); + // let program = builder.compile_isa(); + // execute_program::(program); // println!("{}", code);