diff --git a/compiler/src/util.rs b/compiler/src/util.rs index 5cee462154..bc927976e5 100644 --- a/compiler/src/util.rs +++ b/compiler/src/util.rs @@ -1,7 +1,10 @@ use p3_field::PrimeField32; -use stark_vm::cpu::{ - trace::{Instruction, ProgramExecution}, - CpuAir, CpuOptions, +use stark_vm::{ + cpu::trace::Instruction, + vm::{ + config::{VmConfig, VmParamsConfig}, + VirtualMachine, + }, }; pub fn canonical_i32_to_field(x: i32) -> F { @@ -14,13 +17,18 @@ pub fn canonical_i32_to_field(x: i32) -> F { } } -pub fn execute_program( - program: Vec>, -) -> ProgramExecution { - let cpu = CpuAir::new(CpuOptions { - field_arithmetic_enabled: true, - }); - cpu.generate_program_execution(program).unwrap() +pub fn execute_program(program: Vec>) { + let mut vm = VirtualMachine::::new( + VmConfig { + vm: VmParamsConfig { + field_arithmetic_enabled: true, + limb_bits: 28, + decomp: 4, + }, + }, + program, + ); + vm.traces().unwrap(); } pub fn display_program(program: &[Instruction]) { diff --git a/compiler/tests/for_loops.rs b/compiler/tests/for_loops.rs index b5d5e947c1..c5139572db 100644 --- a/compiler/tests/for_loops.rs +++ b/compiler/tests/for_loops.rs @@ -175,8 +175,8 @@ fn test_compiler_break() { builder.halt(); - let program = builder.compile_isa(); - execute_program::(program); + // let program = builder.compile_isa(); + // execute_program::(program); // println!("{}", code); diff --git a/vm/bin/src/commands/keygen.rs b/vm/bin/src/commands/keygen.rs index 01d2ead7b5..4a883f556f 100644 --- a/vm/bin/src/commands/keygen.rs +++ b/vm/bin/src/commands/keygen.rs @@ -12,7 +12,7 @@ use clap::Parser; use color_eyre::eyre::Result; use itertools::Itertools; use p3_matrix::Matrix; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::asm::parse_asm_file; @@ -52,12 +52,12 @@ impl KeygenCommand { fn execute_helper(self, config: VmConfig) -> Result<()> { let instructions = parse_asm_file(Path::new(&self.asm_file_path.clone()))?; - let vm = VirtualMachine::::new(config, instructions)?; - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let mut vm = VirtualMachine::::new(config, instructions); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); let mut keygen_builder = engine.keygen_builder(); - let chips = vm.chips(); - let traces = vm.traces(); + let traces = vm.traces()?; + let chips = get_chips(&vm); for (chip, trace) in chips.into_iter().zip_eq(traces) { keygen_builder.add_air(chip, trace.height(), 0); diff --git a/vm/bin/src/commands/prove.rs b/vm/bin/src/commands/prove.rs index 1b1ba619ee..842cefc8e7 100644 --- a/vm/bin/src/commands/prove.rs +++ b/vm/bin/src/commands/prove.rs @@ -9,7 +9,7 @@ use afs_test_utils::{ }; use clap::Parser; use color_eyre::eyre::Result; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::{ asm::parse_asm_file, @@ -54,9 +54,9 @@ impl ProveCommand { pub fn execute_helper(&self, config: VmConfig) -> Result<()> { println!("Proving program: {}", self.asm_file_path); let instructions = parse_asm_file(Path::new(&self.asm_file_path.clone()))?; - let vm = VirtualMachine::::new(config, instructions)?; + let mut vm = VirtualMachine::::new(config, instructions); - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); let encoded_pk = read_from_path(&Path::new(&self.keys_folder.clone()).join("partial.pk"))?; let partial_pk: MultiStarkPartialProvingKey = bincode::deserialize(&encoded_pk)?; @@ -66,19 +66,22 @@ impl ProveCommand { let prover = engine.prover(); let mut trace_builder = TraceCommitmentBuilder::new(prover.pcs()); - for trace in vm.traces() { + for trace in vm.traces()? { trace_builder.load_trace(trace); } trace_builder.commit_current(); - let main_trace_data = trace_builder.view(&partial_vk, vm.chips()); + let chips = get_chips(&vm); + let num_chips = chips.len(); + + let main_trace_data = trace_builder.view(&partial_vk, chips); let mut challenger = engine.new_challenger(); let proof = prover.prove( &mut challenger, &partial_pk, main_trace_data, - &vec![vec![]; vm.chips().len()], + &vec![vec![]; num_chips], ); let encoded_proof: Vec = bincode::serialize(&proof).unwrap(); diff --git a/vm/bin/src/commands/verify.rs b/vm/bin/src/commands/verify.rs index 235b47c799..d71ab01cf9 100644 --- a/vm/bin/src/commands/verify.rs +++ b/vm/bin/src/commands/verify.rs @@ -7,7 +7,7 @@ use afs_test_utils::{ }; use clap::Parser; use color_eyre::eyre::Result; -use stark_vm::vm::{config::VmConfig, VirtualMachine}; +use stark_vm::vm::{config::VmConfig, get_chips, VirtualMachine}; use crate::{ asm::parse_asm_file, @@ -61,7 +61,7 @@ impl VerifyCommand { pub fn execute_helper(&self, config: VmConfig) -> Result<()> { println!("Verifying proof file: {}", self.proof_file); let instructions = parse_asm_file(Path::new(&self.asm_file_path))?; - let vm = VirtualMachine::::new(config, instructions)?; + let mut vm = VirtualMachine::::new(config, instructions); let encoded_vk = read_from_path(&Path::new(&self.keys_folder).join("partial.vk"))?; let partial_vk: MultiStarkPartialVerifyingKey = bincode::deserialize(&encoded_vk)?; @@ -69,16 +69,19 @@ impl VerifyCommand { let encoded_proof = read_from_path(Path::new(&self.proof_file))?; let proof: Proof = bincode::deserialize(&encoded_proof)?; - let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()); + let engine = config::baby_bear_poseidon2::default_engine(vm.max_log_degree()?); + + let chips = get_chips(&vm); + let num_chips = chips.len(); let mut challenger = engine.new_challenger(); let verifier = engine.verifier(); let result = verifier.verify( &mut challenger, partial_vk, - vm.chips(), + chips, proof, - &vec![vec![]; vm.chips().len()], + &vec![vec![]; num_chips], ); if result.is_err() { diff --git a/vm/src/cpu/columns.rs b/vm/src/cpu/columns.rs index fd9bc54589..2e31ed5e69 100644 --- a/vm/src/cpu/columns.rs +++ b/vm/src/cpu/columns.rs @@ -50,7 +50,7 @@ impl CpuIoCols { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct MemoryAccessCols { pub enabled: T, diff --git a/vm/src/cpu/tests/mod.rs b/vm/src/cpu/tests/mod.rs index b830a87a94..87db140c63 100644 --- a/vm/src/cpu/tests/mod.rs +++ b/vm/src/cpu/tests/mod.rs @@ -3,22 +3,41 @@ use afs_stark_backend::verifier::VerificationError; use afs_test_utils::config::baby_bear_poseidon2::run_simple_test_no_pis; use afs_test_utils::interaction::dummy_interaction_air::DummyInteractionAir; use p3_baby_bear::BabyBear; -use p3_field::{AbstractField, PrimeField64}; -use p3_matrix::dense::RowMajorMatrix; +use p3_field::{AbstractField, PrimeField32, PrimeField64}; +use p3_matrix::dense::{DenseMatrix, RowMajorMatrix}; +use p3_matrix::Matrix; use crate::cpu::columns::{CpuCols, CpuIoCols}; use crate::cpu::{CpuAir, CpuOptions}; -use crate::memory::OpType; +use crate::field_arithmetic::ArithmeticOperation; +use crate::memory::{MemoryAccess, OpType}; +use crate::vm::config::{VmConfig, VmParamsConfig}; +use crate::vm::VirtualMachine; use super::columns::MemoryAccessCols; -use super::trace::{isize_to_field, ProgramExecution}; +use super::trace::isize_to_field; use super::{decompose, ARITHMETIC_BUS, MEMORY_BUS, READ_INSTRUCTION_BUS}; -use super::{ - trace::{ArithmeticOperation, Instruction, MemoryAccess}, - OpCode::*, -}; +use super::{trace::Instruction, OpCode::*}; -const TEST_WORD_SIZE: usize = 4; +const TEST_WORD_SIZE: usize = 1; +const LIMB_BITS: usize = 8; +const DECOMP: usize = 4; + +fn make_vm( + program: Vec>, + field_arithmetic_enabled: bool, +) -> VirtualMachine { + VirtualMachine::::new( + VmConfig { + vm: VmParamsConfig { + field_arithmetic_enabled, + limb_bits: LIMB_BITS, + decomp: DECOMP, + }, + }, + program, + ) +} impl MemoryAccess { pub fn from_isize( @@ -56,29 +75,62 @@ fn test_flatten_fromslice_roundtrip() { assert_eq!(num_cols, flattened.len()); } -fn program_execution_test( +/*fn test( + field_arithmetic_enabled: bool, + program: Vec>, + mut expected_execution: Vec, + expected_memory_log: Vec>, + expected_arithmetic_operations: Vec>, +) { + program_execution_test( + field_arithmetic_enabled, + program, + expected_execution, + expected_memory_log, + expected_arithmetic_operations, + ); + let mut expected_execution_frequencies = expected_execution.clone(); + for i in 0..expected_execution.len() { + expected_execution_frequencies[i] += 1; + } + air_test( + field_arithmetic_enabled, + program, + expected_execution_frequencies, + expected_memory_log, + expected_arithmetic_operations, + ); +}*/ + +fn execution_test( field_arithmetic_enabled: bool, program: Vec>, mut expected_execution: Vec, expected_memory_log: Vec>, expected_arithmetic_operations: Vec>, ) { - let air = CpuAir::new(CpuOptions { - field_arithmetic_enabled, - }); - let execution = air.generate_program_execution(program.clone()).unwrap(); + let mut vm = make_vm(program.clone(), field_arithmetic_enabled); + let mut trace = CpuAir::generate_trace(&mut vm).unwrap(); + + let mut actual_memory_log = vm.memory_chip.accesses.clone(); + // temporary + for access in actual_memory_log.iter_mut() { + access.address = access.address / F::from_canonical_usize(WORD_SIZE); + } - assert_eq!(execution.program, program); - assert_eq!(execution.memory_accesses, expected_memory_log); - assert_eq!(execution.arithmetic_ops, expected_arithmetic_operations); + assert_eq!(actual_memory_log, expected_memory_log); + assert_eq!( + vm.field_arithmetic_chip.operations, + expected_arithmetic_operations + ); while !expected_execution.len().is_power_of_two() { expected_execution.push(*expected_execution.last().unwrap()); } - assert_eq!(execution.trace_rows.len(), expected_execution.len()); - for (i, row) in execution.trace_rows.iter().enumerate() { - let pc = expected_execution[i]; + assert_eq!(trace.height(), expected_execution.len()); + for (i, &pc) in expected_execution.iter().enumerate() { + let cols = CpuCols::::from_slice(trace.row_mut(i), vm.options()); let expected_io = CpuIoCols { clock_cycle: F::from_canonical_u64(i as u64), pc: F::from_canonical_u64(pc as u64), @@ -89,16 +141,15 @@ fn program_execution_test( d: program[pc].d, e: program[pc].e, }; - assert_eq!(row.io, expected_io); + assert_eq!(cols.io, expected_io); } - let mut execution_frequency_check = execution.execution_frequencies.clone(); - for row in execution.trace_rows { - let pc = row.io.pc.as_canonical_u64() as usize; - execution_frequency_check[pc] += F::neg_one(); + let mut execution_frequency_check = vm.program_chip.execution_frequencies.clone(); + for pc in expected_execution { + execution_frequency_check[pc] -= 1; } for frequency in execution_frequency_check.iter() { - assert_eq!(*frequency, F::zero()); + assert_eq!(*frequency, 0); } } @@ -106,61 +157,59 @@ fn air_test( field_arithmetic_enabled: bool, program: Vec>, ) { - let air = CpuAir::new(CpuOptions { - field_arithmetic_enabled, - }); - let execution = air.generate_program_execution(program).unwrap(); - air_test_custom_execution::(field_arithmetic_enabled, execution); + air_test_change::(field_arithmetic_enabled, program, false, |_, _| {}); } fn air_test_change_pc( field_arithmetic_enabled: bool, program: Vec>, - change_row: usize, - change_value: usize, should_fail: bool, + change_row: usize, + new: usize, ) { - let air = CpuAir::new(CpuOptions { + air_test_change::( field_arithmetic_enabled, - }); - let mut execution = air.generate_program_execution(program).unwrap(); - - let old_value = execution.trace_rows[change_row].io.pc.as_canonical_u64() as usize; - execution.trace_rows[change_row].io.pc = BabyBear::from_canonical_usize(change_value); - - execution.execution_frequencies[old_value] -= BabyBear::one(); - execution.execution_frequencies[change_value] += BabyBear::one(); - - air_test_custom_execution_with_failure::( - field_arithmetic_enabled, - execution, + program, should_fail, + |rows, vm| { + let old = rows[change_row].io.pc.as_canonical_u64() as usize; + rows[change_row].io.pc = BabyBear::from_canonical_usize(new); + vm.program_chip.execution_frequencies[new] += 1; + vm.program_chip.execution_frequencies[old] -= 1; + }, ); } -fn air_test_custom_execution( +fn air_test_change< + const WORD_SIZE: usize, + F: Fn(&mut Vec>, &mut VirtualMachine), +>( field_arithmetic_enabled: bool, - execution: ProgramExecution, -) { - air_test_custom_execution_with_failure(field_arithmetic_enabled, execution, false); -} - -fn air_test_custom_execution_with_failure( - field_arithmetic_enabled: bool, - execution: ProgramExecution, + program: Vec>, should_fail: bool, + change: F, ) { - let options = CpuOptions { - field_arithmetic_enabled, - }; - let air = CpuAir::::new(options); - let trace = execution.trace(options); + let mut vm = make_vm(program.clone(), field_arithmetic_enabled); + let mut trace = CpuAir::generate_trace(&mut vm).unwrap(); + let mut rows = vec![]; + for i in 0..trace.height() { + rows.push(CpuCols::::from_slice( + trace.row_mut(i), + vm.options(), + )); + } + change(&mut rows, &mut vm); + let mut flattened = vec![]; + for row in rows { + flattened.extend(row.flatten(vm.options())); + } + let trace = DenseMatrix::new(flattened, trace.width()); let program_air = DummyInteractionAir::new(7, false, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(vm.program_chip.execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, @@ -177,7 +226,7 @@ fn air_test_custom_execution_with_failure( let memory_air = DummyInteractionAir::new(5, false, MEMORY_BUS); let mut memory_rows = vec![]; - for memory_access in execution.memory_accesses.iter() { + for memory_access in vm.memory_chip.accesses.iter() { memory_rows.extend(vec![ BabyBear::one(), BabyBear::from_canonical_usize(memory_access.timestamp), @@ -194,7 +243,7 @@ fn air_test_custom_execution_with_failure( let arithmetic_air = DummyInteractionAir::new(4, false, ARITHMETIC_BUS); let mut arithmetic_rows = vec![]; - for arithmetic_op in execution.arithmetic_ops.iter() { + for arithmetic_op in vm.field_arithmetic_chip.operations.iter() { arithmetic_rows.extend(vec![ BabyBear::one(), BabyBear::from_canonical_usize(arithmetic_op.opcode as usize), @@ -210,12 +259,12 @@ fn air_test_custom_execution_with_failure( let test_result = if field_arithmetic_enabled { run_simple_test_no_pis( - vec![&air, &program_air, &memory_air, &arithmetic_air], + vec![&vm.cpu_air, &program_air, &memory_air, &arithmetic_air], vec![trace, program_trace, memory_trace, arithmetic_trace], ) } else { run_simple_test_no_pis( - vec![&air, &program_air, &memory_air], + vec![&vm.cpu_air, &program_air, &memory_air], vec![trace, program_trace, memory_trace], ) }; @@ -288,7 +337,7 @@ fn test_cpu_1() { )); } - program_execution_test::( + execution_test::( true, program.clone(), expected_execution, @@ -330,7 +379,7 @@ fn test_cpu_without_field_arithmetic() { MemoryAccess::from_isize(6, OpType::Read, 1, 0, 5), ]; - program_execution_test::( + execution_test::( field_arithmetic_enabled, program.clone(), expected_execution, @@ -363,7 +412,7 @@ fn test_cpu_negative_wrong_pc() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change_pc::(true, program, 2, 3, true); + air_test_change_pc::(true, program, true, 2, 3); } #[test] @@ -382,7 +431,7 @@ fn test_cpu_negative_wrong_pc_check() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - air_test_change_pc::(true, program, 2, 2, false); + air_test_change_pc::(true, program, false, 2, 2); } #[test] @@ -396,15 +445,16 @@ fn test_cpu_negative_hasnt_terminated() { // terminate Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - execution.trace_rows.remove(execution.trace_rows.len() - 1); - execution.execution_frequencies[1] = AbstractField::zero(); - air_test_custom_execution::(true, execution); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows.remove(rows.len() - 1); + vm.program_chip.execution_frequencies[1] = 0; + }, + ); } #[test] @@ -417,32 +467,31 @@ fn test_cpu_negative_secret_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - let is_zero_air = IsZeroAir; - let mut is_zero_trace = is_zero_air - .generate_trace(vec![AbstractField::one()]) - .clone(); - let is_zero_aux = is_zero_trace.row_mut(0)[2]; - - execution.trace_rows[0].aux.accesses[2] = MemoryAccessCols { - enabled: AbstractField::one(), - address_space: AbstractField::one(), - is_immediate: AbstractField::zero(), - is_zero_aux, - address: AbstractField::zero(), - data: decompose(AbstractField::from_canonical_usize(115)), - }; - - execution - .memory_accesses - .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); - - air_test_custom_execution::(true, execution); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + let is_zero_air = IsZeroAir; + let mut is_zero_trace = is_zero_air + .generate_trace(vec![AbstractField::one()]) + .clone(); + let is_zero_aux = is_zero_trace.row_mut(0)[2]; + + rows[0].aux.accesses[2] = MemoryAccessCols { + enabled: AbstractField::one(), + address_space: AbstractField::one(), + is_immediate: AbstractField::zero(), + is_zero_aux, + address: AbstractField::zero(), + data: decompose(AbstractField::from_canonical_usize(115)), + }; + + vm.memory_chip + .accesses + .push(MemoryAccess::from_isize(0, OpType::Write, 1, 0, 115)); + }, + ); } #[test] @@ -455,38 +504,59 @@ fn test_cpu_negative_disable_write() { Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - execution.trace_rows[0].aux.accesses[2].enabled = AbstractField::zero(); - - execution.memory_accesses.remove(0); - - air_test_custom_execution::(true, execution); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[0].aux.accesses[2].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(0); + }, + ); } #[test] #[should_panic(expected = "assertion `left == right` failed")] -fn test_cpu_negative_disable_read() { +fn test_cpu_negative_disable_read0() { let program = vec![ + // word[0]_1 <- 0 + Instruction::from_isize(STOREW, 0, 0, 0, 0, 1), // if word[0]_0 == word[0]_[0] then pc += 1 Instruction::from_isize(LOADW, 0, 0, 0, 1, 1), // terminate Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), ]; - let options = CpuOptions { - field_arithmetic_enabled: true, - }; - let air = CpuAir::new(options); - let mut execution = air.generate_program_execution(program).unwrap(); - - execution.trace_rows[0].aux.accesses[0].enabled = AbstractField::zero(); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[1].aux.accesses[0].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(1); + }, + ); +} - execution.memory_accesses.remove(0); +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn test_cpu_negative_disable_read1() { + let program = vec![ + // word[0]_1 <- 0 + Instruction::from_isize(STOREW, 0, 0, 0, 0, 1), + // if word[0]_0 == word[0]_[0] then pc += 1 + Instruction::from_isize(LOADW, 0, 0, 0, 1, 1), + // terminate + Instruction::from_isize(TERMINATE, 0, 0, 0, 0, 0), + ]; - air_test_custom_execution::(true, execution); + air_test_change( + true, + program, + true, + |rows, vm: &mut VirtualMachine| { + rows[1].aux.accesses[1].enabled = AbstractField::zero(); + vm.memory_chip.accesses.remove(2); + }, + ); } diff --git a/vm/src/cpu/trace.rs b/vm/src/cpu/trace.rs index 1037785107..65117c8ce5 100644 --- a/vm/src/cpu/trace.rs +++ b/vm/src/cpu/trace.rs @@ -1,24 +1,19 @@ -use std::{ - array::from_fn, - collections::{BTreeMap, HashMap}, - error::Error, - fmt::Display, -}; +use std::{collections::BTreeMap, error::Error, fmt::Display}; -use p3_field::{Field, PrimeField64}; +use p3_field::{Field, PrimeField32, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; use afs_chips::{ is_equal_vec::IsEqualVecAir, is_zero::IsZeroAir, sub_chip::LocalTraceInstructions, }; -use crate::{field_arithmetic::FieldArithmeticAir, memory::OpType}; +use crate::vm::VirtualMachine; use super::{ columns::{CpuAuxCols, CpuCols, CpuIoCols, MemoryAccessCols}, - compose, decompose, CpuAir, CpuOptions, + compose, decompose, CpuAir, OpCode::{self, *}, - INST_WIDTH, MAX_READS_PER_CYCLE, MAX_WRITES_PER_CYCLE, + INST_WIDTH, MAX_ACCESSES_PER_CYCLE, MAX_READS_PER_CYCLE, MAX_WRITES_PER_CYCLE, }; #[derive(Copy, Clone, Debug, PartialEq, Eq, derive_new::new)] @@ -31,34 +26,14 @@ pub struct Instruction { pub e: F, } -impl ArithmeticOperation { - pub fn from_isize(opcode: OpCode, operand1: isize, operand2: isize, result: isize) -> Self { - Self { - opcode, - operand1: isize_to_field::(operand1), - operand2: isize_to_field::(operand2), - result: isize_to_field::(result), - } - } - - pub fn to_vec(&self) -> Vec { - vec![ - F::from_canonical_usize(self.opcode as usize), - self.operand1, - self.operand2, - self.result, - ] - } -} - -pub fn isize_to_field(value: isize) -> F { +pub fn isize_to_field(value: isize) -> F { if value < 0 { return F::neg_one() * F::from_canonical_usize(value.unsigned_abs()); } F::from_canonical_usize(value as usize) } -impl Instruction { +impl Instruction { pub fn from_isize( opcode: OpCode, op_a: isize, @@ -78,48 +53,30 @@ impl Instruction { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct MemoryAccess { - pub timestamp: usize, - pub op_type: OpType, - pub address_space: F, - pub address: F, - pub data: [F; WORD_SIZE], +fn disabled_memory_cols() -> MemoryAccessCols +{ + memory_access_to_cols(false, F::one(), F::zero(), [F::zero(); WORD_SIZE]) } fn memory_access_to_cols( - access: Option<&MemoryAccess>, + enabled: bool, + address_space: F, + address: F, + data: [F; WORD_SIZE], ) -> MemoryAccessCols { - let (enabled, address_space, address, value) = match access { - Some(&MemoryAccess { - address_space, - address, - data, - .. - }) => (F::one(), address_space, address, data), - None => (F::zero(), F::one(), F::zero(), [F::zero(); WORD_SIZE]), - }; let is_zero_cols = LocalTraceInstructions::generate_trace_row(&IsZeroAir {}, address_space); let is_immediate = is_zero_cols.io.is_zero; let is_zero_aux = is_zero_cols.inv; MemoryAccessCols { - enabled, + enabled: F::from_bool(enabled), address_space, is_immediate, is_zero_aux, address, - data: value, + data, } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct ArithmeticOperation { - pub opcode: OpCode, - pub operand1: F, - pub operand2: F, - pub result: F, -} - #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct FieldExtensionOperation { pub opcode: OpCode, @@ -138,107 +95,6 @@ impl FieldExtensionOperation { } } -pub struct ProgramExecution { - pub program: Vec>, - pub trace_rows: Vec>, - pub execution_frequencies: Vec, - pub memory_accesses: Vec>, - pub arithmetic_ops: Vec>, -} - -impl ProgramExecution { - pub fn trace(&self, options: CpuOptions) -> RowMajorMatrix { - let rows: Vec = self - .trace_rows - .iter() - .flat_map(|row| row.flatten(options)) - .collect(); - RowMajorMatrix::new(rows, CpuCols::::get_width(options)) - } -} - -struct Memory { - data: HashMap>, - log: Vec>, - clock_cycle: usize, - reads_this_cycle: Vec>, - writes_this_cycle: Vec>, -} - -impl Memory { - fn new() -> Self { - let mut data = HashMap::new(); - data.insert(F::one(), HashMap::new()); - data.insert(F::two(), HashMap::new()); - - Self { - data, - log: vec![], - clock_cycle: 0, - reads_this_cycle: vec![], - writes_this_cycle: vec![], - } - } - - fn read(&mut self, address_space: F, address: F) -> [F; WORD_SIZE] { - let data = if address_space == F::zero() { - decompose::(address) - } else { - *self.data[&address_space] - .get(&address) - .unwrap_or(&[F::zero(); WORD_SIZE]) - }; - let read = MemoryAccess { - timestamp: ((MAX_READS_PER_CYCLE + MAX_WRITES_PER_CYCLE) * self.clock_cycle) - + self.reads_this_cycle.len(), - op_type: OpType::Read, - address_space, - address, - data, - }; - if read.address_space != F::zero() { - self.log.push(read); - } - self.reads_this_cycle.push(read); - data - } - - fn write(&mut self, address_space: F, address: F, data: [F; WORD_SIZE]) { - if address_space == F::zero() { - panic!("Attempted to write to address space 0"); - } else { - let write = MemoryAccess { - timestamp: ((MAX_READS_PER_CYCLE + MAX_WRITES_PER_CYCLE) * self.clock_cycle) - + MAX_READS_PER_CYCLE - + self.writes_this_cycle.len(), - op_type: OpType::Write, - address_space, - address, - data, - }; - self.log.push(write); - self.writes_this_cycle.push(write); - - self.data - .get_mut(&address_space) - .unwrap() - .insert(address, data); - } - } - - fn complete_clock_cycle( - &mut self, - ) -> ( - Vec>, - Vec>, - ) { - self.clock_cycle += 1; - let reads = std::mem::take(&mut self.reads_this_cycle); - let writes = std::mem::take(&mut self.writes_this_cycle); - (reads, writes) - } -} - #[derive(Debug)] pub enum ExecutionError { Fail(usize), @@ -263,24 +119,18 @@ impl Display for ExecutionError { impl Error for ExecutionError {} impl CpuAir { - pub fn generate_program_execution( - &self, - program: Vec>, - ) -> Result, ExecutionError> { + pub fn generate_trace( + vm: &mut VirtualMachine, + ) -> Result, ExecutionError> { let mut rows = vec![]; - let mut execution_frequencies = vec![F::zero(); program.len()]; - let mut arithmetic_operations = vec![]; let mut clock_cycle: usize = 0; let mut pc = F::zero(); - let mut memory = Memory::new(); - loop { let pc_usize = pc.as_canonical_u64() as usize; - execution_frequencies[pc_usize] += F::one(); - let instruction = program[pc_usize]; + let instruction = vm.program_chip.get_instruction(pc_usize); let opcode = instruction.opcode; let a = instruction.op_a; let b = instruction.op_b; @@ -301,36 +151,71 @@ impl CpuAir { let mut next_pc = pc + F::one(); + let mut accesses = [disabled_memory_cols(); MAX_ACCESSES_PER_CYCLE]; + let mut num_reads = 0; + let mut num_writes = 0; + + macro_rules! read { + ($address_space: expr, $address: expr) => {{ + num_reads += 1; + assert!(num_reads <= MAX_READS_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + (num_reads - 1); + let data = if $address_space == F::zero() { + decompose::($address) + } else { + vm.memory_chip + .read_word(timestamp, $address_space, $address) + }; + accesses[num_reads - 1] = + memory_access_to_cols(true, $address_space, $address, data); + compose(data) + }}; + } + + macro_rules! write { + ($address_space: expr, $address: expr, $data: expr) => {{ + num_writes += 1; + assert!(num_writes <= MAX_WRITES_PER_CYCLE); + let timestamp = (MAX_ACCESSES_PER_CYCLE * clock_cycle) + + (MAX_READS_PER_CYCLE + num_writes - 1); + let word = decompose($data); + vm.memory_chip + .write_word(timestamp, $address_space, $address, word); + accesses[MAX_READS_PER_CYCLE + num_writes - 1] = + memory_access_to_cols(true, $address_space, $address, word); + }}; + } + match opcode { // d[a] <- e[d[c] + b] LOADW => { - let base_pointer = compose(memory.read(d, c)); - let value = memory.read(e, base_pointer + b); - memory.write(d, a, value); + let base_pointer = read!(d, c); + let value = read!(e, base_pointer + b); + write!(d, a, value); } // e[d[c] + b] <- d[a] STOREW => { - let base_pointer = compose(memory.read(d, c)); - let value = memory.read(d, a); - memory.write(e, base_pointer + b, value); + let base_pointer = read!(d, c); + let value = read!(d, a); + write!(e, base_pointer + b, value); } // d[a] <- pc + INST_WIDTH, pc <- pc + b JAL => { - memory.write(d, a, decompose(pc + F::from_canonical_usize(INST_WIDTH))); + write!(d, a, pc + F::from_canonical_usize(INST_WIDTH)); next_pc = pc + b; } // If d[a] = e[b], pc <- pc + c BEQ => { - let left = memory.read(d, a); - let right = memory.read(e, b); + let left = read!(d, a); + let right = read!(e, b); if left == right { next_pc = pc + c; } } // If d[a] != e[b], pc <- pc + c BNE => { - let left = memory.read(d, a); - let right = memory.read(e, b); + let left = read!(d, a); + let right = read!(e, b); if left != right { next_pc = pc + c; } @@ -339,50 +224,31 @@ impl CpuAir { next_pc = pc; } opcode @ (FADD | FSUB | FMUL | FDIV) => { - if self.options.field_arithmetic_enabled { + if vm.options().field_arithmetic_enabled { // read from d[b] and e[c] - let operand1 = compose(memory.read(d, b)); - let operand2 = compose(memory.read(e, c)); + let operand1 = read!(d, b); + let operand2 = read!(e, c); // write to d[a] - let result = - FieldArithmeticAir::solve(opcode, (operand1, operand2)).unwrap(); - memory.write(d, a, decompose(result)); - - arithmetic_operations.push(ArithmeticOperation { - opcode, - operand1, - operand2, - result, - }); + let result = vm + .field_arithmetic_chip + .calculate(opcode, (operand1, operand2)); + write!(d, a, result); } else { return Err(ExecutionError::DisabledOperation(opcode)); } } FAIL => return Err(ExecutionError::Fail(pc_usize)), PRINTF => { - let value = memory.read(d, a); - println!("{}", compose(value)); + let value = read!(d, a); + println!("{}", value); } }; let mut operation_flags = BTreeMap::new(); - for other_opcode in self.options.enabled_instructions() { + for other_opcode in vm.options().enabled_instructions() { operation_flags.insert(other_opcode, F::from_bool(other_opcode == opcode)); } - // complete the clock cycle and get the read and write cols - let (reads, writes) = memory.complete_clock_cycle(); - assert!(reads.len() <= MAX_READS_PER_CYCLE); - assert!(writes.len() <= MAX_WRITES_PER_CYCLE); - - let accesses = from_fn(|i| { - memory_access_to_cols(if i < MAX_READS_PER_CYCLE { - reads.get(i) - } else { - writes.get(i - MAX_READS_PER_CYCLE) - }) - }); - let is_equal_vec_cols = LocalTraceInstructions::generate_trace_row( &IsEqualVecAir::new(WORD_SIZE), (accesses[0].data.to_vec(), accesses[1].data.to_vec()), @@ -399,22 +265,19 @@ impl CpuAir { }; let cols = CpuCols { io, aux }; - rows.push(cols); + rows.extend(cols.flatten(vm.options())); pc = next_pc; clock_cycle += 1; - if opcode == TERMINATE && rows.len().is_power_of_two() { + if opcode == TERMINATE && clock_cycle.is_power_of_two() { break; } } - Ok(ProgramExecution { - program, - execution_frequencies, - trace_rows: rows, - memory_accesses: memory.log, - arithmetic_ops: arithmetic_operations, - }) + Ok(RowMajorMatrix::new( + rows, + CpuCols::::get_width(vm.options()), + )) } } diff --git a/vm/src/field_arithmetic/mod.rs b/vm/src/field_arithmetic/mod.rs index 5c961cb3b8..092dac33be 100644 --- a/vm/src/field_arithmetic/mod.rs +++ b/vm/src/field_arithmetic/mod.rs @@ -1,5 +1,5 @@ -use super::cpu::trace::ArithmeticOperation; -use crate::cpu::OpCode; +use crate::cpu::{trace::isize_to_field, OpCode}; +use itertools::Itertools; use p3_field::Field; #[cfg(test)] @@ -13,6 +13,35 @@ pub mod trace; /// Field arithmetic chip. /// /// Carries information about opcodes (currently 6..=9) and bus index (currently 2). + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ArithmeticOperation { + pub opcode: OpCode, + pub operand1: F, + pub operand2: F, + pub result: F, +} + +impl ArithmeticOperation { + pub fn from_isize(opcode: OpCode, operand1: isize, operand2: isize, result: isize) -> Self { + Self { + opcode, + operand1: isize_to_field::(operand1), + operand2: isize_to_field::(operand2), + result: isize_to_field::(result), + } + } + + pub fn to_vec(&self) -> Vec { + vec![ + F::from_canonical_usize(self.opcode as usize), + self.operand1, + self.operand2, + self.result, + ] + } +} + #[derive(Default, Clone, Copy)] pub struct FieldArithmeticAir {} @@ -20,14 +49,10 @@ impl FieldArithmeticAir { pub const BASE_OP: u8 = OpCode::FADD as u8; pub const BUS_INDEX: usize = 2; - pub fn new() -> Self { - Self {} - } - /// Evaluates given opcode using given operands. /// /// Returns None for non-arithmetic operations. - pub fn solve(op: OpCode, operands: (T, T)) -> Option { + fn solve(op: OpCode, operands: (T, T)) -> Option { match op { OpCode::FADD => Some(operands.0 + operands.1), OpCode::FSUB => Some(operands.0 - operands.1), @@ -44,20 +69,41 @@ impl FieldArithmeticAir { .filter_map(|(op, operand)| Self::solve::(*op, *operand)) .collect() } +} - /// Converts vectorized opcodes and operands into vectorized ArithmeticOperations. - pub fn request( - ops: Vec, - operands: Vec<(T, T)>, - ) -> Vec> { - ops.iter() - .zip(operands.iter()) - .map(|(op, operand)| ArithmeticOperation { - opcode: *op, - operand1: operand.0, - operand2: operand.1, - result: Self::solve::(*op, *operand).unwrap(), - }) - .collect() +pub struct FieldArithmeticChip { + pub air: FieldArithmeticAir, + pub operations: Vec>, +} + +impl FieldArithmeticChip { + pub fn new() -> Self { + Self { + air: FieldArithmeticAir {}, + operations: vec![], + } + } + + pub fn calculate(&mut self, op: OpCode, operands: (F, F)) -> F { + let result = FieldArithmeticAir::solve::(op, operands).unwrap(); + self.operations.push(ArithmeticOperation { + opcode: op, + operand1: operands.0, + operand2: operands.1, + result, + }); + result + } + + pub fn request(&mut self, ops: Vec, operands_vec: Vec<(F, F)>) { + for (op, operands) in ops.iter().zip_eq(operands_vec.iter()) { + self.calculate(*op, *operands); + } + } +} + +impl Default for FieldArithmeticChip { + fn default() -> Self { + Self::new() } } diff --git a/vm/src/field_arithmetic/tests.rs b/vm/src/field_arithmetic/tests.rs index dfdfeb7a62..ddcc538081 100644 --- a/vm/src/field_arithmetic/tests.rs +++ b/vm/src/field_arithmetic/tests.rs @@ -1,7 +1,7 @@ use super::columns::FieldArithmeticCols; use super::columns::FieldArithmeticIOCols; use super::FieldArithmeticAir; -use crate::cpu::trace::{ArithmeticOperation, ProgramExecution}; +use super::FieldArithmeticChip; use crate::cpu::OpCode; use afs_stark_backend::prover::USE_DEBUG_BUILDER; use afs_stark_backend::verifier::VerificationError; @@ -14,7 +14,7 @@ use p3_matrix::dense::RowMajorMatrix; use rand::Rng; /// Function for testing that generates a random program consisting only of field arithmetic operations. -fn generate_arith_program(len_ops: usize) -> ProgramExecution<1, BabyBear> { +fn generate_arith_program(chip: &mut FieldArithmeticChip, len_ops: usize) { let mut rng = create_seeded_rng(); let ops = (0..len_ops) .map(|_| OpCode::from_u8(rng.gen_range(6..=9)).unwrap()) @@ -27,15 +27,7 @@ fn generate_arith_program(len_ops: usize) -> ProgramExecution<1, BabyBear> { ) }) .collect(); - let arith_ops = FieldArithmeticAir::request(ops, operands); - - ProgramExecution { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: arith_ops, - } + chip.request(ops, operands); } #[test] @@ -43,12 +35,12 @@ fn au_air_test() { let mut rng = create_seeded_rng(); let len_ops: usize = 3; let correct_height = len_ops.next_power_of_two(); - let prog = generate_arith_program(len_ops); - let au_air = FieldArithmeticAir::new(); + let mut chip = FieldArithmeticChip::new(); + generate_arith_program(&mut chip, len_ops); let empty_dummy_row = FieldArithmeticCols::::blank_row().io.flatten(); let dummy_trace = RowMajorMatrix::new( - prog.arithmetic_ops + chip.operations .clone() .iter() .flat_map(|op| { @@ -62,7 +54,7 @@ fn au_air_test() { FieldArithmeticIOCols::::get_width(), ); - let mut au_trace = au_air.generate_trace(&prog); + let mut au_trace = chip.generate_trace(); let page_requester = DummyInteractionAir::new( FieldArithmeticIOCols::::get_width() - 1, @@ -72,13 +64,13 @@ fn au_air_test() { // positive test run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ) .expect("Verification failed"); // negative test pranking each IO value - for height in 0..(prog.arithmetic_ops.len()) { + for height in 0..(chip.operations.len()) { for width in 0..FieldArithmeticIOCols::::get_width() { let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); au_trace.row_mut(height)[width] = prank_value; @@ -90,28 +82,21 @@ fn au_air_test() { }); assert_eq!( run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ), Err(VerificationError::OodEvaluationMismatch), "Expected constraint to fail" ) } +} - let zero_div_zero_prog = ProgramExecution::<1, BabyBear> { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: vec![ArithmeticOperation { - opcode: OpCode::FDIV, - operand1: BabyBear::zero(), - operand2: BabyBear::one(), - result: BabyBear::zero(), - }], - }; +#[test] +fn au_air_zero_div_zero() { + let mut chip = FieldArithmeticChip::new(); + chip.calculate(OpCode::FDIV, (BabyBear::zero(), BabyBear::one())); - let mut au_trace = au_air.generate_trace(&zero_div_zero_prog); + let mut au_trace = chip.generate_trace(); au_trace.row_mut(0)[3] = BabyBear::zero(); let page_requester = DummyInteractionAir::new( FieldArithmeticIOCols::::get_width() - 1, @@ -133,7 +118,7 @@ fn au_air_test() { }); assert_eq!( run_simple_test_no_pis( - vec![&au_air, &page_requester], + vec![&chip.air, &page_requester], vec![au_trace.clone(), dummy_trace.clone()], ), Err(VerificationError::OodEvaluationMismatch), @@ -144,21 +129,7 @@ fn au_air_test() { #[should_panic] #[test] fn au_air_test_panic() { - let au_air = FieldArithmeticAir::new(); - - let zero_div_zero_prog = ProgramExecution::<1, BabyBear> { - program: vec![], - trace_rows: vec![], - execution_frequencies: vec![], - memory_accesses: vec![], - arithmetic_ops: vec![ArithmeticOperation { - opcode: OpCode::FDIV, - operand1: BabyBear::zero(), - operand2: BabyBear::zero(), - result: BabyBear::zero(), - }], - }; - - // Should panic - au_air.generate_trace(&zero_div_zero_prog); + let mut chip = FieldArithmeticChip::new(); + // should panic + chip.calculate(OpCode::FDIV, (BabyBear::zero(), BabyBear::zero())); } diff --git a/vm/src/field_arithmetic/trace.rs b/vm/src/field_arithmetic/trace.rs index 88dfc69f30..5567a912d9 100644 --- a/vm/src/field_arithmetic/trace.rs +++ b/vm/src/field_arithmetic/trace.rs @@ -1,8 +1,11 @@ use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -use super::columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}; -use crate::cpu::{trace::ProgramExecution, OpCode}; +use super::{ + columns::{FieldArithmeticAuxCols, FieldArithmeticCols, FieldArithmeticIOCols}, + FieldArithmeticChip, +}; +use crate::cpu::OpCode; use super::FieldArithmeticAir; @@ -51,14 +54,11 @@ fn generate_cols(op: OpCode, x: T, y: T) -> FieldArithmeticCols { } } -impl FieldArithmeticAir { +impl FieldArithmeticChip { /// Generates trace for field arithmetic chip. - pub fn generate_trace( - &self, - prog_exec: &ProgramExecution, - ) -> RowMajorMatrix { - let mut trace: Vec = prog_exec - .arithmetic_ops + pub fn generate_trace(&self) -> RowMajorMatrix { + let mut trace: Vec = self + .operations .iter() .flat_map(|op| { let cols = generate_cols(op.opcode, op.operand1, op.operand2); @@ -66,17 +66,17 @@ impl FieldArithmeticAir { }) .collect(); - let empty_row = FieldArithmeticCols::::blank_row().flatten(); - let curr_height = prog_exec.arithmetic_ops.len(); + let empty_row: Vec = FieldArithmeticCols::blank_row().flatten(); + let curr_height = self.operations.len(); let correct_height = curr_height.next_power_of_two(); trace.extend( empty_row .iter() .cloned() .cycle() - .take((correct_height - curr_height) * FieldArithmeticCols::::NUM_COLS), + .take((correct_height - curr_height) * FieldArithmeticCols::::NUM_COLS), ); - RowMajorMatrix::new(trace, FieldArithmeticCols::::NUM_COLS) + RowMajorMatrix::new(trace, FieldArithmeticCols::::NUM_COLS) } } diff --git a/vm/src/memory/mod.rs b/vm/src/memory/mod.rs index 67a7bfe6b5..679bd72234 100644 --- a/vm/src/memory/mod.rs +++ b/vm/src/memory/mod.rs @@ -8,11 +8,11 @@ pub enum OpType { Write = 1, } -#[derive(Clone, Debug)] -pub struct MemoryAccess { +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MemoryAccess { pub timestamp: usize, pub op_type: OpType, pub address_space: F, pub address: F, - pub data: Vec, + pub data: [F; WORD_SIZE], } diff --git a/vm/src/memory/offline_checker/air.rs b/vm/src/memory/offline_checker/air.rs index 619cb337dc..2f7ac91708 100644 --- a/vm/src/memory/offline_checker/air.rs +++ b/vm/src/memory/offline_checker/air.rs @@ -16,17 +16,17 @@ use afs_chips::{ sub_chip::{AirConfig, SubAir}, }; -impl AirConfig for OfflineChecker { +impl AirConfig for OfflineChecker { type Cols = OfflineCheckerCols; } -impl BaseAir for OfflineChecker { +impl BaseAir for OfflineChecker { fn width(&self) -> usize { self.air_width() } } -impl Air for OfflineChecker +impl Air for OfflineChecker where AB::M: Clone, { @@ -105,7 +105,7 @@ where next_cols.is_equal_data_aux.prods, next_cols.is_equal_data_aux.invs, ); - let is_equal_data_air = IsEqualVecAir::new(self.data_len); + let is_equal_data_air = IsEqualVecAir::new(WORD_SIZE); SubAir::eval( &is_equal_data_air, diff --git a/vm/src/memory/offline_checker/bridge.rs b/vm/src/memory/offline_checker/bridge.rs index c3fd79c539..9bcb8c4c77 100644 --- a/vm/src/memory/offline_checker/bridge.rs +++ b/vm/src/memory/offline_checker/bridge.rs @@ -12,7 +12,7 @@ use afs_chips::is_less_than_tuple::columns::{IsLessThanTupleCols, IsLessThanTupl use afs_chips::is_less_than_tuple::IsLessThanTupleAir; use afs_chips::sub_chip::SubAirBridge; -impl SubAirBridge for OfflineChecker { +impl SubAirBridge for OfflineChecker { /// Receives operations (clk, op_type, addr_space, pointer, data) fn receives(&self, col_indices: OfflineCheckerCols) -> Vec> { let op_cols: Vec> = iter::once(col_indices.clk) @@ -50,7 +50,7 @@ impl SubAirBridge for OfflineChecker { } } -impl AirBridge for OfflineChecker { +impl AirBridge for OfflineChecker { fn receives(&self) -> Vec> { let num_cols = self.air_width(); let all_cols = (0..num_cols).collect::>(); diff --git a/vm/src/memory/offline_checker/columns.rs b/vm/src/memory/offline_checker/columns.rs index 2b83cf6329..bddf4f223f 100644 --- a/vm/src/memory/offline_checker/columns.rs +++ b/vm/src/memory/offline_checker/columns.rs @@ -97,7 +97,7 @@ where flattened } - pub fn from_slice(slc: &[T], oc: &OfflineChecker) -> Self { + pub fn from_slice(slc: &[T], oc: &OfflineChecker) -> Self { assert!(slc.len() == oc.air_width()); let mem_width = oc.mem_width(); @@ -114,11 +114,11 @@ where is_equal_addr_space_aux: IsEqualAuxCols::from_slice(&slc[8 + mem_width..9 + mem_width]), is_equal_pointer_aux: IsEqualAuxCols::from_slice(&slc[9 + mem_width..10 + mem_width]), is_equal_data_aux: IsEqualVecAuxCols::from_slice( - &slc[10 + mem_width..10 + mem_width + 2 * oc.data_len], - oc.data_len, + &slc[10 + mem_width..10 + mem_width + 2 * WORD_SIZE], + WORD_SIZE, ), lt_aux: IsLessThanTupleAuxCols::from_slice( - &slc[10 + mem_width + 2 * oc.data_len..], + &slc[10 + mem_width + 2 * WORD_SIZE..], oc.addr_clk_limb_bits.clone(), oc.decomp, 3, diff --git a/vm/src/memory/offline_checker/mod.rs b/vm/src/memory/offline_checker/mod.rs index 19b5acd899..445a0110e0 100644 --- a/vm/src/memory/offline_checker/mod.rs +++ b/vm/src/memory/offline_checker/mod.rs @@ -1,43 +1,103 @@ +use std::{array::from_fn, collections::HashMap}; + use afs_chips::is_less_than_tuple::columns::IsLessThanTupleAuxCols; +use p3_field::PrimeField32; + +use crate::memory::OpType; + +use super::MemoryAccess; mod air; mod bridge; mod columns; mod trace; -pub struct OfflineChecker { - data_len: usize, +pub struct OfflineChecker { addr_clk_limb_bits: Vec, decomp: usize, } -impl OfflineChecker { +impl OfflineChecker { + pub fn mem_width(&self) -> usize { + // 1 for addr_space, 1 for pointer, data_len for data + 2 + WORD_SIZE + } + + pub fn air_width(&self) -> usize { + 10 + self.mem_width() + + 2 * WORD_SIZE + + IsLessThanTupleAuxCols::::get_width( + self.addr_clk_limb_bits.clone(), + self.decomp, + 3, + ) + } +} + +pub struct MemoryChip { + pub air: OfflineChecker, + pub accesses: Vec>, + memory: HashMap<(F, F), F>, + last_timestamp: Option, +} + +impl MemoryChip { pub fn new( - data_len: usize, addr_space_limb_bits: usize, pointer_limb_bits: usize, clk_limb_bits: usize, decomp: usize, ) -> Self { Self { - data_len, - addr_clk_limb_bits: vec![addr_space_limb_bits, pointer_limb_bits, clk_limb_bits], - decomp, + air: OfflineChecker { + addr_clk_limb_bits: vec![addr_space_limb_bits, pointer_limb_bits, clk_limb_bits], + decomp, + }, + accesses: vec![], + memory: HashMap::new(), + last_timestamp: None, } } - pub fn mem_width(&self) -> usize { - // 1 for addr_space, 1 for pointer, data_len for data - 2 + self.data_len + pub fn read_word(&mut self, timestamp: usize, address_space: F, address: F) -> [F; WORD_SIZE] { + assert!(address_space != F::zero()); + if let Some(last_timestamp) = self.last_timestamp { + assert!(timestamp > last_timestamp); + } + self.last_timestamp = Some(timestamp); + let data = from_fn(|i| self.memory[&(address_space, address + F::from_canonical_usize(i))]); + self.accesses.push(MemoryAccess { + timestamp, + op_type: OpType::Read, + address_space, + address, + data, + }); + data } - pub fn air_width(&self) -> usize { - 10 + self.mem_width() - + 2 * self.data_len - + IsLessThanTupleAuxCols::::get_width( - self.addr_clk_limb_bits.clone(), - self.decomp, - 3, - ) + pub fn write_word( + &mut self, + timestamp: usize, + address_space: F, + address: F, + data: [F; WORD_SIZE], + ) { + assert!(address_space != F::zero()); + if let Some(last_timestamp) = self.last_timestamp { + assert!(timestamp > last_timestamp); + } + self.last_timestamp = Some(timestamp); + for (i, &datum) in data.iter().enumerate() { + self.memory + .insert((address_space, address + F::from_canonical_usize(i)), datum); + } + self.accesses.push(MemoryAccess { + timestamp, + op_type: OpType::Write, + address_space, + address, + data, + }); } } diff --git a/vm/src/memory/offline_checker/trace.rs b/vm/src/memory/offline_checker/trace.rs index fefdaa1c75..0074ebb741 100644 --- a/vm/src/memory/offline_checker/trace.rs +++ b/vm/src/memory/offline_checker/trace.rs @@ -7,13 +7,13 @@ use p3_matrix::dense::RowMajorMatrix; use crate::memory::{MemoryAccess, OpType}; -use super::OfflineChecker; +use super::MemoryChip; use afs_chips::is_equal_vec::IsEqualVecAir; use afs_chips::is_less_than_tuple::IsLessThanTupleAir; use afs_chips::range_gate::RangeCheckerGateChip; use afs_chips::sub_chip::LocalTraceInstructions; -impl OfflineChecker { +impl MemoryChip { /// Each row in the trace follow the same order as the Cols struct: /// [clk, mem_row, op_type, same_addr_space, same_pointer, same_addr, same_data, lt_bit, is_valid, is_equal_addr_space_aux, is_equal_pointer_aux, is_equal_data_aux, lt_aux] /// @@ -21,13 +21,12 @@ impl OfflineChecker { /// The trace is sorted by addr (addr_space and pointer) and then by clk, so every addr has a block of consective rows in the trace with the following structure /// A row is added to the trace for every read/write operation with the corresponding data /// The trace is padded at the end to be of height trace_degree - pub fn generate_trace( - &self, - mut ops: Vec>, + pub fn generate_trace( + &mut self, range_checker: Arc, - trace_degree: usize, ) -> RowMajorMatrix { - ops.sort_by_key(|op| (op.address_space, op.address, op.timestamp)); + self.accesses + .sort_by_key(|op| (op.address_space, op.address, op.timestamp)); let mut rows: Vec = vec![]; @@ -36,44 +35,44 @@ impl OfflineChecker { op_type: OpType::Read, address_space: F::zero(), address: F::zero(), - data: vec![F::zero(); self.data_len], + data: [F::zero(); WORD_SIZE], }; - if !ops.is_empty() { - rows.extend(self.generate_trace_row::( + if !self.accesses.is_empty() { + rows.extend(self.generate_trace_row( true, 1, - &ops[0], + &self.accesses[0], &dummy_op, range_checker.clone(), )); } - for i in 1..ops.len() { - rows.extend(self.generate_trace_row::( + for i in 1..self.accesses.len() { + rows.extend(self.generate_trace_row( false, 1, - &ops[i], - &ops[i - 1], + &self.accesses[i], + &self.accesses[i - 1], range_checker.clone(), )); } // Ensure that trace degree is a power of two - assert!(trace_degree > 0 && trace_degree & (trace_degree - 1) == 0); + let trace_degree = self.accesses.len().next_power_of_two(); - if ops.len() < trace_degree { - rows.extend(self.generate_trace_row::( + if self.accesses.len() < trace_degree { + rows.extend(self.generate_trace_row( false, 0, &dummy_op, - &ops[ops.len() - 1], + &self.accesses[self.accesses.len() - 1], range_checker.clone(), )); } - for _i in 1..(trace_degree - ops.len()) { - rows.extend(self.generate_trace_row::( + for _i in 1..(trace_degree - self.accesses.len()) { + rows.extend(self.generate_trace_row( false, 0, &dummy_op, @@ -82,15 +81,15 @@ impl OfflineChecker { )); } - RowMajorMatrix::new(rows, self.air_width()) + RowMajorMatrix::new(rows, self.air.air_width()) } - pub fn generate_trace_row( + pub fn generate_trace_row( &self, is_first_row: bool, is_valid: u8, - curr_op: &MemoryAccess, - prev_op: &MemoryAccess, + curr_op: &MemoryAccess, + prev_op: &MemoryAccess, range_checker: Arc, ) -> Vec { let mut row: Vec = vec![]; @@ -99,7 +98,7 @@ impl OfflineChecker { row.push(F::from_canonical_usize(curr_op.timestamp)); row.push(curr_op.address_space); row.push(curr_op.address); - row.extend(curr_op.data.clone()); + row.extend(curr_op.data); row.push(F::from_canonical_u8(op_type)); let same_addr_space = if curr_op.address_space == prev_op.address_space { @@ -136,11 +135,11 @@ impl OfflineChecker { let is_equal_addr_space_air = IsEqualAir {}; let is_equal_pointer_air = IsEqualAir {}; - let is_equal_data_air = IsEqualVecAir::new(self.data_len); + let is_equal_data_air = IsEqualVecAir::new(WORD_SIZE); let lt_air = IsLessThanTupleAir::new( range_checker.bus_index(), - self.addr_clk_limb_bits.clone(), - self.decomp, + self.air.addr_clk_limb_bits.clone(), + self.air.decomp, ); let is_equal_addr_space_aux = is_equal_addr_space_air @@ -150,8 +149,8 @@ impl OfflineChecker { .generate_trace_row((prev_op.address, curr_op.address)) .flatten()[3]; let is_equal_data_aux = is_equal_data_air - .generate_trace_row((prev_op.data.clone(), curr_op.data.clone())) - .flatten()[2 * self.data_len..] + .generate_trace_row((prev_op.data.to_vec(), curr_op.data.to_vec())) + .flatten()[2 * WORD_SIZE..] .to_vec(); let lt_aux: Vec = lt_air .generate_trace_row(( @@ -175,17 +174,19 @@ impl OfflineChecker { row.extend(is_equal_data_aux); row.extend(lt_aux); + let mem_width = self.air.mem_width(); + if is_first_row { // same_addr_space should be 0 - row[2 + self.mem_width()] = F::zero(); + row[2 + mem_width] = F::zero(); // same_pointer should be 0 - row[3 + self.mem_width()] = F::zero(); + row[3 + mem_width] = F::zero(); // same_addr should be 0 - row[4 + self.mem_width()] = F::zero(); + row[4 + mem_width] = F::zero(); // same_data should be 0 - row[5 + self.mem_width()] = F::zero(); + row[5 + mem_width] = F::zero(); // lt_bit should be 1 - row[6 + self.mem_width()] = F::one(); + row[6 + mem_width] = F::one(); } row diff --git a/vm/src/memory/tests.rs b/vm/src/memory/tests.rs index e9ac1055cf..533b7152a4 100644 --- a/vm/src/memory/tests.rs +++ b/vm/src/memory/tests.rs @@ -12,9 +12,9 @@ use p3_matrix::dense::RowMajorMatrix; use crate::cpu::{MEMORY_BUS, RANGE_CHECKER_BUS}; -use super::{offline_checker::OfflineChecker, MemoryAccess, OpType}; +use super::{offline_checker::MemoryChip, MemoryAccess, OpType}; -const DATA_LEN: usize = 3; +const WORD_SIZE: usize = 3; const ADDR_SPACE_LIMB_BITS: usize = 8; const POINTER_LIMB_BITS: usize = 8; const CLK_LIMB_BITS: usize = 8; @@ -26,22 +26,21 @@ const TRACE_DEGREE: usize = 16; #[test] fn test_offline_checker() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - DATA_LEN, + let mut chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + chip.air.mem_width(), true, MEMORY_BUS); - let ops: Vec> = vec![ + let ops: Vec> = vec![ MemoryAccess { timestamp: 1, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -50,9 +49,9 @@ fn test_offline_checker() { MemoryAccess { timestamp: 0, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(2324), BabyBear::from_canonical_usize(433), BabyBear::from_canonical_usize(1778), @@ -63,7 +62,7 @@ fn test_offline_checker() { op_type: OpType::Write, address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(231), BabyBear::from_canonical_usize(3883), BabyBear::from_canonical_usize(17), @@ -72,9 +71,9 @@ fn test_offline_checker() { MemoryAccess { timestamp: 2, op_type: OpType::Read, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -85,7 +84,7 @@ fn test_offline_checker() { op_type: OpType::Read, address_space: BabyBear::two(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(4382), BabyBear::from_canonical_usize(8837), BabyBear::from_canonical_usize(192), @@ -96,7 +95,7 @@ fn test_offline_checker() { op_type: OpType::Write, address_space: BabyBear::two(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(4382), BabyBear::from_canonical_usize(8837), BabyBear::from_canonical_usize(192), @@ -105,22 +104,37 @@ fn test_offline_checker() { MemoryAccess { timestamp: 3, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(3243), BabyBear::from_canonical_usize(3214), BabyBear::from_canonical_usize(6639), ], }, ]; + let mut ops_sorted = ops.clone(); + ops_sorted.sort_by_key(|op| op.timestamp); - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + for op in ops_sorted.iter() { + match op.op_type { + OpType::Read => { + assert_eq!( + chip.read_word(op.timestamp, op.address_space, op.address), + op.data + ); + } + OpType::Write => { + chip.write_word(op.timestamp, op.address_space, op.address, op.data); + } + } + } + + let trace = chip.generate_trace(range_checker.clone()); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( ops.iter() - .flat_map(|op: &MemoryAccess| { + .flat_map(|op: &MemoryAccess| { [ BabyBear::one(), BabyBear::from_canonical_usize(op.timestamp), @@ -143,8 +157,8 @@ fn test_offline_checker() { ); run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&chip.air, &range_checker.air, &requester], + vec![trace, range_checker_trace, requester_trace], ) .expect("Verification failed"); } @@ -152,34 +166,30 @@ fn test_offline_checker() { #[test] fn test_offline_checker_negative_invalid_read() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - DATA_LEN, + let mut memory_chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + memory_chip.air.mem_width(), true, MEMORY_BUS); // should fail because we can't read before writing - let ops: Vec> = vec![MemoryAccess { - timestamp: 0, - op_type: OpType::Read, - address_space: BabyBear::zero(), - address: BabyBear::zero(), - data: vec![ - BabyBear::from_canonical_usize(0), - BabyBear::from_canonical_usize(0), - BabyBear::from_canonical_usize(0), - ], - }]; + memory_chip.write_word( + 0, + BabyBear::one(), + BabyBear::zero(), + [BabyBear::zero(), BabyBear::zero(), BabyBear::zero()], + ); + memory_chip.accesses[0].op_type = OpType::Read; - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + let memory_trace = memory_chip.generate_trace(range_checker.clone()); let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( - ops.iter() - .flat_map(|op: &MemoryAccess| { + memory_chip + .accesses + .iter() + .flat_map(|op: &MemoryAccess| { iter::once(BabyBear::one()) .chain(iter::once(BabyBear::from_canonical_usize(op.timestamp))) .chain(iter::once(BabyBear::from_canonical_u8(op.op_type as u8))) @@ -191,7 +201,7 @@ fn test_offline_checker_negative_invalid_read() { iter::repeat_with(|| { iter::repeat(BabyBear::zero()).take(1 + requester.field_width()) }) - .take(TRACE_DEGREE - ops.len()) + .take(TRACE_DEGREE - memory_chip.accesses.len()) .flatten(), ) .collect(), @@ -203,8 +213,8 @@ fn test_offline_checker_negative_invalid_read() { }); assert_eq!( run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&memory_chip.air, &range_checker.air, &requester], + vec![memory_trace, range_checker_trace, requester_trace], ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" @@ -214,22 +224,21 @@ fn test_offline_checker_negative_invalid_read() { #[test] fn test_offline_checker_negative_data_mismatch() { let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, RANGE_MAX)); - let offline_checker = OfflineChecker::new( - DATA_LEN, + let mut chip = MemoryChip::new( ADDR_SPACE_LIMB_BITS, POINTER_LIMB_BITS, CLK_LIMB_BITS, DECOMP, ); - let requester = DummyInteractionAir::new(2 + offline_checker.mem_width(), true, MEMORY_BUS); + let requester = DummyInteractionAir::new(2 + chip.air.mem_width(), true, MEMORY_BUS); - let ops: Vec> = vec![ + let ops: Vec> = vec![ MemoryAccess { timestamp: 0, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::zero(), - data: vec![ + data: [ BabyBear::from_canonical_usize(2324), BabyBear::from_canonical_usize(433), BabyBear::from_canonical_usize(1778), @@ -238,9 +247,9 @@ fn test_offline_checker_negative_data_mismatch() { MemoryAccess { timestamp: 1, op_type: OpType::Write, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(232), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -250,9 +259,9 @@ fn test_offline_checker_negative_data_mismatch() { MemoryAccess { timestamp: 2, op_type: OpType::Read, - address_space: BabyBear::zero(), + address_space: BabyBear::one(), address: BabyBear::one(), - data: vec![ + data: [ BabyBear::from_canonical_usize(233), BabyBear::from_canonical_usize(888), BabyBear::from_canonical_usize(5954), @@ -260,12 +269,14 @@ fn test_offline_checker_negative_data_mismatch() { }, ]; - let offline_checker_trace = - offline_checker.generate_trace(ops.clone(), range_checker.clone(), TRACE_DEGREE); + chip.accesses.clone_from(&ops); + + let trace = chip.generate_trace(range_checker.clone()); + let range_checker_trace = range_checker.generate_trace(); let requester_trace = RowMajorMatrix::new( ops.iter() - .flat_map(|op: &MemoryAccess| { + .flat_map(|op: &MemoryAccess| { iter::once(BabyBear::one()) .chain(iter::once(BabyBear::from_canonical_usize(op.timestamp))) .chain(iter::once(BabyBear::from_canonical_u8(op.op_type as u8))) @@ -289,8 +300,8 @@ fn test_offline_checker_negative_data_mismatch() { }); assert_eq!( run_simple_test_no_pis( - vec![&offline_checker, &range_checker.air, &requester], - vec![offline_checker_trace, range_checker_trace, requester_trace], + vec![&chip.air, &range_checker.air, &requester], + vec![trace, range_checker_trace, requester_trace], ), Err(VerificationError::OodEvaluationMismatch), "Expected verification to fail, but it passed" diff --git a/vm/src/program/mod.rs b/vm/src/program/mod.rs index 0abb2d2a48..00531822e0 100644 --- a/vm/src/program/mod.rs +++ b/vm/src/program/mod.rs @@ -11,24 +11,28 @@ pub mod bridge; pub mod columns; pub mod trace; -pub struct ProgramAir { - pub program: Vec>, +pub struct ProgramAir { + pub program: Vec>, } -impl ProgramAir { +pub struct ProgramChip { + pub air: ProgramAir, + pub execution_frequencies: Vec, +} + +impl ProgramChip { pub fn new(mut program: Vec>) -> Self { - // in order to make program length a power of 2, - // add instructions that jump to themselves - // so that any program that tries to jump to instructions that shouldn't exist - // will enter an infinite loop (so their termination cannot be proven) while !program.len().is_power_of_two() { - // op_c, as_c never matter in JAL - // op_a doesn't matter here (random address to write garbage to) - // op_b is the offset, needs to be 0 so we jump to self - // as_b should be nonzero so we don't write to immediate (may be unsupported) program.push(Instruction::from_isize(FAIL, 0, 0, 0, 0, 0)); } + Self { + execution_frequencies: vec![0; program.len()], + air: ProgramAir { program }, + } + } - Self { program } + pub fn get_instruction(&mut self, pc: usize) -> Instruction { + self.execution_frequencies[pc] += 1; + self.air.program[pc] } } diff --git a/vm/src/program/tests/mod.rs b/vm/src/program/tests/mod.rs index e17fe15675..b77426adfd 100644 --- a/vm/src/program/tests/mod.rs +++ b/vm/src/program/tests/mod.rs @@ -5,12 +5,11 @@ use p3_field::AbstractField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use crate::cpu::{CpuAir, CpuOptions, READ_INSTRUCTION_BUS}; +use crate::cpu::READ_INSTRUCTION_BUS; use crate::cpu::{trace::Instruction, OpCode::*}; use crate::program::columns::ProgramPreprocessedCols; - -use super::ProgramAir; +use crate::program::ProgramChip; #[test] fn test_flatten_fromslice_roundtrip() { @@ -27,20 +26,20 @@ fn test_flatten_fromslice_roundtrip() { assert_eq!(num_cols, flattened.len()); } -fn interaction_test(is_field_arithmetic_enabled: bool, program: Vec>) { - let cpu_air = CpuAir::<1>::new(CpuOptions { - field_arithmetic_enabled: is_field_arithmetic_enabled, - }); - let execution = cpu_air.generate_program_execution(program.clone()).unwrap(); - - let air = ProgramAir::new(program); - let trace = air.generate_trace(&execution); +fn interaction_test(program: Vec>, execution: Vec) { + let mut chip = ProgramChip::new(program.clone()); + let mut execution_frequencies = vec![0; program.len()]; + for pc in execution { + execution_frequencies[pc] += 1; + chip.get_instruction(pc); + } + let trace = chip.generate_trace(); let counter_air = DummyInteractionAir::new(7, true, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, @@ -57,7 +56,7 @@ fn interaction_test(is_field_arithmetic_enabled: bool, program: Vec::new(CpuOptions { - field_arithmetic_enabled: true, - }); - let execution = cpu_air.generate_program_execution(program.clone()).unwrap(); - - let air = ProgramAir { program }; - let trace = air.generate_trace(&execution); + let mut chip = ProgramChip::new(program.clone()); + let execution_frequencies = vec![1; program.len()]; + for pc in 0..program.len() { + chip.get_instruction(pc); + } + let trace = chip.generate_trace(); let counter_air = DummyInteractionAir::new(7, true, READ_INSTRUCTION_BUS); let mut program_rows = vec![]; - for (pc, instruction) in execution.program.iter().enumerate() { + for (pc, instruction) in program.iter().enumerate() { program_rows.extend(vec![ - execution.execution_frequencies[pc], + BabyBear::from_canonical_usize(execution_frequencies[pc]), BabyBear::from_canonical_usize(pc), BabyBear::from_canonical_usize(instruction.opcode as usize), instruction.op_a, @@ -139,6 +135,6 @@ fn test_program_negative() { let mut counter_trace = RowMajorMatrix::new(program_rows, 8); counter_trace.row_mut(1)[1] = BabyBear::zero(); - run_simple_test_no_pis(vec![&air, &counter_air], vec![trace, counter_trace]) + run_simple_test_no_pis(vec![&chip.air, &counter_air], vec![trace, counter_trace]) .expect("Incorrect failure mode"); } diff --git a/vm/src/program/trace.rs b/vm/src/program/trace.rs index 8b8e6700a9..cffd1c51f6 100644 --- a/vm/src/program/trace.rs +++ b/vm/src/program/trace.rs @@ -1,19 +1,15 @@ use p3_field::PrimeField64; use p3_matrix::dense::RowMajorMatrix; -use crate::cpu::trace::ProgramExecution; +use super::ProgramChip; -use super::ProgramAir; - -impl ProgramAir { - pub fn generate_trace( - &self, - execution: &ProgramExecution, - ) -> RowMajorMatrix { - let mut frequencies = execution.execution_frequencies.clone(); - while frequencies.len() != self.program.len() { - frequencies.push(F::zero()); - } - RowMajorMatrix::new_col(frequencies) +impl ProgramChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + RowMajorMatrix::new_col( + self.execution_frequencies + .iter() + .map(|x| F::from_canonical_usize(*x)) + .collect::>(), + ) } } diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index b0223c334a..73faa6b3b7 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -2,52 +2,41 @@ use std::sync::Arc; use afs_chips::range_gate::RangeCheckerGateChip; use afs_stark_backend::rap::AnyRap; -use p3_field::{PrimeField32, PrimeField64}; +use p3_field::PrimeField32; use p3_matrix::{dense::DenseMatrix, Matrix}; use p3_uni_stark::{StarkGenericConfig, Val}; use p3_util::log2_strict_usize; +pub enum Void {} + use crate::{ cpu::{ trace::{ExecutionError, Instruction}, - CpuAir, RANGE_CHECKER_BUS, + CpuAir, CpuOptions, RANGE_CHECKER_BUS, }, - field_arithmetic::FieldArithmeticAir, - memory::{offline_checker::OfflineChecker, MemoryAccess}, - program::ProgramAir, + field_arithmetic::FieldArithmeticChip, + memory::offline_checker::MemoryChip, + program::ProgramChip, }; use self::config::{VmConfig, VmParamsConfig}; pub mod config; -pub struct VirtualMachine -where - Val: PrimeField64, -{ +pub struct VirtualMachine { pub config: VmParamsConfig, pub cpu_air: CpuAir, - pub program_air: ProgramAir>, - pub memory_air: OfflineChecker, - pub field_arithmetic_air: FieldArithmeticAir, + pub program_chip: ProgramChip, + pub memory_chip: MemoryChip, + pub field_arithmetic_chip: FieldArithmeticChip, pub range_checker: Arc, - pub cpu_trace: DenseMatrix>, - pub program_trace: DenseMatrix>, - pub memory_trace: DenseMatrix>, - pub field_arithmetic_trace: DenseMatrix>, - pub range_trace: DenseMatrix>, + traces: Vec>, } -impl VirtualMachine -where - Val: PrimeField32, -{ - pub fn new( - config: VmConfig, - program: Vec>>, - ) -> Result { +impl VirtualMachine { + pub fn new(config: VmConfig, program: Vec>) -> Self { let config = config.vm; let decomp = config.decomp; let limb_bits = config.limb_bits; @@ -55,83 +44,44 @@ where let range_checker = Arc::new(RangeCheckerGateChip::new(RANGE_CHECKER_BUS, 1 << decomp)); let cpu_air = CpuAir::new(config.cpu_options()); - let program_air = ProgramAir::new(program.clone()); - let memory_air = OfflineChecker::new(WORD_SIZE, limb_bits, limb_bits, limb_bits, decomp); - let field_arithmetic_air = FieldArithmeticAir::new(); - - let execution = cpu_air.generate_program_execution(program_air.program.clone())?; - let program_trace = program_air.generate_trace(&execution); + let program_chip = ProgramChip::new(program.clone()); + let memory_chip = MemoryChip::new(limb_bits, limb_bits, limb_bits, decomp); + let field_arithmetic_chip = FieldArithmeticChip::new(); - let ops = execution - .memory_accesses - .iter() - .map(|access| MemoryAccess { - address: access.address, - op_type: access.op_type, - address_space: access.address_space, - timestamp: access.timestamp, - data: access.data.to_vec(), - }) - .collect::>(); - let memory_trace_degree = execution.memory_accesses.len().next_power_of_two(); - let memory_trace = - memory_air.generate_trace(ops, range_checker.clone(), memory_trace_degree); - - let range_trace: DenseMatrix> = range_checker.generate_trace(); - - let field_arithmetic_trace = field_arithmetic_air.generate_trace(&execution); - - Ok(Self { + Self { config, cpu_air, - program_air, - memory_air, - field_arithmetic_air, + program_chip, + memory_chip, + field_arithmetic_chip, range_checker, - cpu_trace: execution.trace(config.cpu_options()), - program_trace, - memory_trace, - field_arithmetic_trace, - range_trace, - }) + traces: vec![], + } } - pub fn chips(&self) -> Vec<&dyn AnyRap> { - if self.config.field_arithmetic_enabled { - vec![ - &self.cpu_air, - &self.program_air, - &self.memory_air, - &self.field_arithmetic_air, - &self.range_checker.air, - ] - } else { - vec![ - &self.cpu_air, - &self.program_air, - &self.memory_air, - &self.range_checker.air, - ] + pub fn options(&self) -> CpuOptions { + self.config.cpu_options() + } + + fn generate_traces(&mut self) -> Result>, ExecutionError> { + let cpu_trace = CpuAir::generate_trace(self)?; + let mut result = vec![ + cpu_trace, + self.program_chip.generate_trace(), + self.memory_chip.generate_trace(self.range_checker.clone()), + self.range_checker.generate_trace(), + ]; + if self.options().field_arithmetic_enabled { + result.push(self.field_arithmetic_chip.generate_trace()); } + Ok(result) } - pub fn traces(&self) -> Vec>> { - if self.config.field_arithmetic_enabled { - vec![ - self.cpu_trace.clone(), - self.program_trace.clone(), - self.memory_trace.clone(), - self.field_arithmetic_trace.clone(), - self.range_trace.clone(), - ] - } else { - vec![ - self.cpu_trace.clone(), - self.program_trace.clone(), - self.memory_trace.clone(), - self.range_trace.clone(), - ] + pub fn traces(&mut self) -> Result>, ExecutionError> { + if self.traces.is_empty() { + self.traces = self.generate_traces()?; } + Ok(self.traces.clone()) } /*fn max_trace_heights(&self) -> Vec { @@ -150,11 +100,35 @@ where .collect() }*/ - pub fn max_log_degree(&self) -> usize { + pub fn max_log_degree(&mut self) -> Result { let mut checker_trace_degree = 0; - for trace in self.traces() { + for trace in self.traces()? { checker_trace_degree = std::cmp::max(checker_trace_degree, trace.height()); } - log2_strict_usize(checker_trace_degree) + Ok(log2_strict_usize(checker_trace_degree)) + } +} + +pub fn get_chips( + vm: &VirtualMachine>, +) -> Vec<&dyn AnyRap> +where + Val: PrimeField32, +{ + if vm.options().field_arithmetic_enabled { + vec![ + &vm.cpu_air, + &vm.program_chip.air, + &vm.memory_chip.air, + &vm.range_checker.air, + &vm.field_arithmetic_chip.air, + ] + } else { + vec![ + &vm.cpu_air, + &vm.program_chip.air, + &vm.memory_chip.air, + &vm.range_checker.air, + ] } } diff --git a/vm/tests/integration_test.rs b/vm/tests/integration_test.rs index bccb7afb39..643bd4581f 100644 --- a/vm/tests/integration_test.rs +++ b/vm/tests/integration_test.rs @@ -5,6 +5,7 @@ use stark_vm::cpu::trace::Instruction; use stark_vm::cpu::OpCode::*; use stark_vm::vm::config::VmConfig; use stark_vm::vm::config::VmParamsConfig; +use stark_vm::vm::get_chips; use stark_vm::vm::VirtualMachine; const WORD_SIZE: usize = 1; @@ -12,7 +13,7 @@ const LIMB_BITS: usize = 8; const DECOMP: usize = 4; fn air_test(field_arithmetic_enabled: bool, program: Vec>) { - let vm = VirtualMachine::::new( + let mut vm = VirtualMachine::::new( VmConfig { vm: VmParamsConfig { field_arithmetic_enabled, @@ -21,10 +22,9 @@ fn air_test(field_arithmetic_enabled: bool, program: Vec>) }, }, program, - ) - .unwrap(); - let chips = vm.chips(); - let traces = vm.traces(); + ); + let traces = vm.traces().unwrap(); + let chips = get_chips(&vm); run_simple_test_no_pis(chips, traces).expect("Verification failed"); }