diff --git a/core/src/air/bool.rs b/core/src/air/bool.rs index 4df25f4336..e922c575d0 100644 --- a/core/src/air/bool.rs +++ b/core/src/air/bool.rs @@ -1,10 +1,12 @@ +use core::borrow::{Borrow, BorrowMut}; use p3_air::AirBuilder; use p3_field::AbstractField; use super::AirVariable; +use valida_derive::AlignedBorrow; /// An AIR representation of a boolean value. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default, AlignedBorrow)] pub struct Bool(pub T); impl AirVariable for Bool { diff --git a/core/src/air/word.rs b/core/src/air/word.rs index 94fd63f1ee..55bd4dfb3c 100644 --- a/core/src/air/word.rs +++ b/core/src/air/word.rs @@ -1,6 +1,9 @@ use std::ops::{Index, IndexMut}; +use core::borrow::{Borrow, BorrowMut}; use p3_air::AirBuilder; +use p3_field::Field; +use valida_derive::AlignedBorrow; use super::AirVariable; @@ -8,7 +11,7 @@ use super::AirVariable; const WORD_LEN: usize = 4; /// An AIR representation of a word in the instruction set. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, AlignedBorrow)] pub struct Word(pub [T; WORD_LEN]); impl AirVariable for Word { @@ -34,3 +37,16 @@ impl IndexMut for Word { &mut self.0[index] } } + +impl From for Word { + fn from(value: u32) -> Self { + let inner = value + .to_le_bytes() + .iter() + .map(|v| F::from_canonical_u8(*v)) + .collect::>() + .try_into() + .unwrap(); + Word(inner) + } +} diff --git a/core/src/cpu/air.rs b/core/src/cpu/air.rs index 4ecc157ac5..07e743d12e 100644 --- a/core/src/cpu/air.rs +++ b/core/src/cpu/air.rs @@ -7,26 +7,12 @@ use p3_field::AbstractField; use p3_field::{Field, PrimeField}; use p3_matrix::MatrixRowSlices; use p3_util::indices_arr; +use valida_derive::AlignedBorrow; -#[derive(Debug, Clone, Copy)] -pub struct CpuAir; - -/// An AIR table for memory accesses. -#[derive(Debug, Clone)] -pub struct CpuCols { - /// The clock cycle value. - pub clk: T, - /// The program counter value. - pub pc: T, - /// The opcode for this cycle. - pub opcode: T, - /// The first operand for this instruction. - pub op_a: T, - /// The second operand for this instruction. - pub op_b: T, - /// The third operand for this instruction. - pub op_c: T, - // Whether op_b is an immediate value. +#[derive(AlignedBorrow, Default)] +#[repr(C)] +pub struct OpcodeSelectors { + // // Whether op_b is an immediate value. pub imm_b: T, // Whether op_c is an immediate value. pub imm_c: T, @@ -46,24 +32,51 @@ pub struct CpuCols { pub system_instruction: T, // Whether this is a multiply instruction. pub multiply_instruction: T, - // Selectors for load/store instructions and their types. - pub byte: Bool, - pub half: Bool, - pub word: Bool, - pub unsigned: Bool, - // TODO: we might need a selector for "MULSU" since no other instruction has "SU" + // // Selectors for load/store instructions and their types. + pub byte: T, + pub half: T, + pub word: T, + pub unsigned: T, + // // TODO: we might need a selector for "MULSU" since no other instruction has "SU" pub JALR: T, pub JAL: T, pub AUIPC: T, + // // Whether this instruction is reading from register A. + pub reg_a_read: T, +} + +#[derive(AlignedBorrow, Default)] +#[repr(C)] +pub struct InstructionCols { + // /// The opcode for this cycle. + pub opcode: T, + // /// The first operand for this instruction. + pub op_a: T, + // /// The second operand for this instruction. + pub op_b: T, + // /// The third operand for this instruction. + pub op_c: T, +} - // Operand values, either from registers or immediate values. +/// An AIR table for memory accesses. +#[derive(AlignedBorrow, Default)] +#[repr(C)] +pub struct CpuCols { + /// The clock cycle value. + pub clk: T, + // /// The program counter value. + pub pc: T, + + // Columns related to the instruction. + pub instruction: InstructionCols, + // Selectors for the opcode. + pub selectors: OpcodeSelectors, + + // // Operand values, either from registers or immediate values. pub op_a_val: Word, pub op_b_val: Word, pub op_c_val: Word, - // Whether this instruction is reading from register A. - pub reg_a_read: T, - // An addr that we are reading from or writing to. pub addr: Word, // The associated memory value for `addr`. @@ -81,28 +94,6 @@ const fn make_col_map() -> CpuCols { unsafe { transmute::<[usize; NUM_CPU_COLS], CpuCols>(indices_arr) } } -impl Borrow> for [T] { - fn borrow(&self) -> &CpuCols { - // TODO: Double check if this is correct & consider making asserts debug-only. - let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; - assert!(prefix.is_empty(), "Data was not aligned"); - assert!(suffix.is_empty(), "Data was not aligned"); - assert_eq!(shorts.len(), 1); - &shorts[0] - } -} - -impl BorrowMut> for [T] { - fn borrow_mut(&mut self) -> &mut CpuCols { - // TODO: Double check if this is correct & consider making asserts debug-only. - let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; - assert!(prefix.is_empty(), "Data was not aligned"); - assert!(suffix.is_empty(), "Data was not aligned"); - assert_eq!(shorts.len(), 1); - &mut shorts[0] - } -} - impl AirConstraint for CpuCols { fn eval(&self, builder: &mut AB) { let main = builder.main(); @@ -120,15 +111,16 @@ impl AirConstraint for CpuCols { //// Constraint op_a_val, op_b_val, op_c_val // Constraint the op_b_val and op_c_val columns when imm_b and imm_c are true. builder - .when(local.imm_b) - .assert_eq(reduce::(local.op_b_val), local.op_b); + .when(local.selectors.imm_b) + .assert_eq(reduce::(local.op_b_val), local.instruction.op_b); builder - .when(local.imm_c) - .assert_eq(reduce::(local.op_c_val), local.op_c); + .when(local.selectors.imm_c) + .assert_eq(reduce::(local.op_c_val), local.instruction.op_c); // We only read from the first register if there is a store or branch instruction. In all other cases we write. - let reg_a_read = - local.store_instruction + local.branch_instruction + local.multiply_instruction; + let reg_a_read = local.selectors.store_instruction + + local.selectors.branch_instruction + + local.selectors.multiply_instruction; //// For r-type, i-type and multiply instructions, we must constraint by an "opcode-oracle" table // TODO: lookup (clk, op_a_val, op_b_val, op_c_val) in the "opcode-oracle" table with multiplicity (register_instruction + immediate_instruction + multiply_instruction) @@ -136,7 +128,7 @@ impl AirConstraint for CpuCols { //// For branch instructions // TODO: lookup (clk, branch_cond_val, op_a_val, op_b_val) in the "branch" table with multiplicity branch_instruction // Increment the pc by 4 + op_c_val * branch_cond_val where we interpret the first result as a bool that it is. - builder.when(local.branch_instruction).assert_eq( + builder.when(local.selectors.branch_instruction).assert_eq( local.pc + AB::F::from_canonical_u8(4) + reduce::(local.op_c_val) * local.branch_cond_val.0[0], @@ -144,23 +136,24 @@ impl AirConstraint for CpuCols { ); //// For jump instructions - builder.when(local.jump_instruction).assert_eq( + builder.when(local.selectors.jump_instruction).assert_eq( reduce::(local.op_a_val), local.pc + AB::F::from_canonical_u8(4), ); - builder.when(local.JAL).assert_eq( + builder.when(local.selectors.JAL).assert_eq( local.pc + AB::F::from_canonical_u8(4) + reduce::(local.op_b_val), next.pc, ); - builder - .when(local.JALR) - .assert_eq(reduce::(local.op_b_val) + local.op_c, next.pc); + builder.when(local.selectors.JALR).assert_eq( + reduce::(local.op_b_val) + local.instruction.op_c, + next.pc, + ); //// For system instructions //// Upper immediate instructions // lookup(clk, op_c_val, imm, 12) in SLT table with multiplicity AUIPC - builder.when(local.AUIPC).assert_eq( + builder.when(local.selectors.AUIPC).assert_eq( reduce::(local.op_a_val), reduce::(local.op_c_val) + local.pc, ); diff --git a/core/src/cpu/mod.rs b/core/src/cpu/mod.rs index 226bdd4851..95c89572f5 100644 --- a/core/src/cpu/mod.rs +++ b/core/src/cpu/mod.rs @@ -8,4 +8,7 @@ pub struct CpuEvent { pub pc: u32, pub instruction: Instruction, pub operands: [u32; 3], + pub addr: Option, + pub memory_value: Option, + pub branch_condition: Option, } diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index 9bf37d5a6c..2275c5aa23 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -1,4 +1,4 @@ -use super::air::{CpuCols, CPU_COL_MAP, NUM_CPU_COLS}; +use super::air::{CpuCols, InstructionCols, OpcodeSelectors, CPU_COL_MAP, NUM_CPU_COLS}; use super::CpuEvent; use crate::lookup::{Interaction, IsRead}; use core::mem::{size_of, transmute}; @@ -6,7 +6,7 @@ use p3_air::{AirBuilder, BaseAir, VirtualPairCol}; use crate::air::Word; use crate::runtime::chip::Chip; -use crate::runtime::{Opcode, Runtime}; +use crate::runtime::{Instruction, Opcode, Runtime}; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; @@ -38,32 +38,34 @@ impl Chip for CpuChip { // lookup (clk, op_a, op_a_val, is_read=reg_a_read) in the register table with multiplicity 1. interactions.push(Interaction::lookup_register( CPU_COL_MAP.clk, - CPU_COL_MAP.op_a, + CPU_COL_MAP.instruction.op_a, CPU_COL_MAP.op_a_val, - IsRead::Expr(VirtualPairCol::single_main(CPU_COL_MAP.reg_a_read)), + IsRead::Expr(VirtualPairCol::single_main( + CPU_COL_MAP.selectors.reg_a_read, + )), VirtualPairCol::constant(F::one()), )); // lookup (clk, op_c, op_c_val, is_read=true) in the register table with multiplicity 1-imm_c // lookup (clk, op_b, op_b_val, is_read=true) in the register table with multiplicity 1-imm_b interactions.push(Interaction::lookup_register( CPU_COL_MAP.clk, - CPU_COL_MAP.op_c, + CPU_COL_MAP.instruction.op_c, CPU_COL_MAP.op_c_val, IsRead::Bool(true), - VirtualPairCol::new_main(vec![(CPU_COL_MAP.imm_c, F::neg_one())], F::one()), // 1-imm_c + VirtualPairCol::new_main(vec![(CPU_COL_MAP.selectors.imm_c, F::neg_one())], F::one()), // 1-imm_c )); interactions.push(Interaction::lookup_register( CPU_COL_MAP.clk, - CPU_COL_MAP.op_b, + CPU_COL_MAP.instruction.op_b, CPU_COL_MAP.op_b_val, IsRead::Bool(true), - VirtualPairCol::new_main(vec![(CPU_COL_MAP.imm_b, F::neg_one())], F::one()), // 1-imm_b + VirtualPairCol::new_main(vec![(CPU_COL_MAP.selectors.imm_b, F::neg_one())], F::one()), // 1-imm_b )); interactions.push(Interaction::add( CPU_COL_MAP.op_a_val, CPU_COL_MAP.op_b_val, CPU_COL_MAP.op_c_val, - VirtualPairCol::single_main(CPU_COL_MAP.register_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.register_instruction), )); //// For both load and store instructions, we must constraint mem_val to be a lookup of [addr] @@ -74,7 +76,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.op_b_val, CPU_COL_MAP.op_c_val, - VirtualPairCol::single_main(CPU_COL_MAP.load_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.load_instruction), )); // To constraint mem_val, we lookup [addr] in the memory table // lookup (clk, addr, mem_val, is_read=true) in the memory table with multiplicity load_instruction @@ -83,7 +85,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.mem_val, IsRead::Bool(true), - VirtualPairCol::single_main(CPU_COL_MAP.load_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.load_instruction), )); // Now we must constraint mem_val and op_a_val // We bus this to a "match_word" table with a combination of s/u and h/b/w @@ -96,7 +98,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.op_a_val, CPU_COL_MAP.op_c_val, - VirtualPairCol::single_main(CPU_COL_MAP.store_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.store_instruction), )); // To constraint mem_val, we lookup [addr] in the memory table // lookup (clk, addr, mem_val, is_read=false) in the memory table with multiplicity store_instruction @@ -105,7 +107,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.mem_val, IsRead::Bool(false), - VirtualPairCol::single_main(CPU_COL_MAP.store_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.store_instruction), )); // Now we must constraint mem_val and op_b_val // TODO: lookup (clk, mem_val, op_b_val, byte, half, word, unsigned) in the "match_word" table with multiplicity store_instruction @@ -118,7 +120,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.mem_val, IsRead::Bool(true), - VirtualPairCol::single_main(CPU_COL_MAP.load_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.load_instruction), )); // Constraint the memory in the case of a store instruction. @@ -127,7 +129,7 @@ impl Chip for CpuChip { CPU_COL_MAP.addr, CPU_COL_MAP.mem_val, IsRead::Bool(false), - VirtualPairCol::single_main(CPU_COL_MAP.store_instruction), + VirtualPairCol::single_main(CPU_COL_MAP.selectors.store_instruction), )); interactions } @@ -144,61 +146,172 @@ impl CpuChip { let cols: &mut CpuCols = unsafe { transmute(&mut row) }; cols.clk = F::from_canonical_u32(event.clk); cols.pc = F::from_canonical_u32(event.pc); - println!("rows: {:?}", row); - cols.opcode = F::from_canonical_u32(event.instruction.opcode as u32); - cols.op_a = F::from_canonical_u32(event.instruction.a as u32); - cols.op_b = F::from_canonical_u32(event.instruction.b as u32); - cols.op_c = F::from_canonical_u32(event.instruction.c as u32); - // TODO: based on the instruction, populate the relevant flags. - match event.instruction.opcode { - Opcode::ADD | Opcode::SUB | Opcode::AND => {} - Opcode::ADDI | Opcode::ANDI => { + + self.populate_instruction(&mut cols.instruction, event.instruction); + self.populate_selectors(&mut cols.selectors, event.instruction.opcode); + + cols.op_a_val = event.operands[0].into(); + cols.op_b_val = event.operands[1].into(); + cols.op_c_val = event.operands[2].into(); + + self.populate_memory(cols, event); + self.populate_branch(cols, event); + row + } + + fn populate_instruction(&self, cols: &mut InstructionCols, instruction: Instruction) { + cols.opcode = F::from_canonical_u32(instruction.opcode as u32); + match instruction.opcode { + Opcode::LUI => { + // For LUI, we convert it to a SLL instruction with imm_b and imm_c turned on. + cols.opcode = F::from_canonical_u32(Opcode::SLL as u32); + assert_eq!(instruction.c as u32, 12); + } + Opcode::AUIPC => { + // For AUIPC, we set the 3rd operand to imm_b << 12. + assert_eq!(instruction.c as u32, instruction.b << 12); + } + _ => {} + } + cols.op_a = F::from_canonical_u32(instruction.a as u32); + cols.op_b = F::from_canonical_u32(instruction.b as u32); + cols.op_c = F::from_canonical_u32(instruction.c as u32); + } + + fn populate_selectors(&self, cols: &mut OpcodeSelectors, opcode: Opcode) { + match opcode { + // Register instructions + Opcode::ADD + | Opcode::SUB + | Opcode::XOR + | Opcode::OR + | Opcode::AND + | Opcode::SLL + | Opcode::SRL + | Opcode::SRA + | Opcode::SLT + | Opcode::SLTU => { + // For register instructions, neither imm_b or imm_c should be turned on. + cols.register_instruction = F::one(); + } + // Immediate instructions + Opcode::ADDI + | Opcode::XORI + | Opcode::ORI + | Opcode::ANDI + | Opcode::SLLI + | Opcode::SRLI + | Opcode::SRAI + | Opcode::SLTI + | Opcode::SLTIU => { + // For immediate instructions, imm_c should be turned on. + cols.imm_c = F::one(); + cols.immediate_instruction = F::one(); + } + // Load instructions + Opcode::LB | Opcode::LH | Opcode::LW | Opcode::LBU | Opcode::LHU => { + // For load instructions, imm_c should be turned on. + cols.imm_c = F::one(); + cols.load_instruction = F::one(); + match opcode { + Opcode::LB | Opcode::LBU => { + cols.byte = F::one(); + } + Opcode::LH | Opcode::LHU => { + cols.half = F::one(); + } + Opcode::LW => { + cols.word = F::one(); + } + _ => {} + } + } + // Store instructions + Opcode::SB | Opcode::SH | Opcode::SW => { + // For store instructions, imm_c should be turned on. cols.imm_c = F::one(); + cols.store_instruction = F::one(); + cols.reg_a_read = F::one(); + match opcode { + Opcode::SB => { + cols.byte = F::one(); + } + Opcode::SH => { + cols.half = F::one(); + } + Opcode::SW => { + cols.word = F::one(); + } + _ => {} + } } + // Branch instructions + Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => { + cols.imm_c = F::one(); + cols.branch_instruction = F::one(); + cols.reg_a_read = F::one(); + } + // Jump instructions Opcode::JAL => { cols.JAL = F::one(); cols.imm_b = F::one(); + cols.imm_c = F::one(); cols.jump_instruction = F::one(); } Opcode::JALR => { cols.JALR = F::one(); + cols.imm_c = F::one(); cols.jump_instruction = F::one(); } + // Upper immediate instructions + Opcode::LUI => { + // Note that we convert a LUI opcode to a SLL opcode with both imm_b and imm_c turned on. + // And the value of imm_c is 12. + cols.imm_b = F::one(); + cols.imm_c = F::one(); + // In order to process lookups for the SLL opcode table, we'll also turn on the "immediate_instruction". + cols.immediate_instruction = F::one(); + } Opcode::AUIPC => { + // Note that for an AUIPC opcode, we turn on both imm_b and imm_c. cols.imm_b = F::one(); + cols.imm_c = F::one(); cols.AUIPC = F::one(); + // We constraint that imm_c = imm_b << 12 by looking up SLL(op_c_val, op_b_val, 12) with multiplicity AUIPC. + // Then we constraint op_a_val = op_c_val + pc by looking up ADD(op_a_val, op_c_val, pc) with multiplicity AUIPC. } - _ => {} + // Multiply instructions + Opcode::MUL + | Opcode::MULH + | Opcode::MULSU + | Opcode::MULU + | Opcode::DIV + | Opcode::DIVU + | Opcode::REM + | Opcode::REMU => { + cols.multiply_instruction = F::one(); + match opcode { + // TODO: set byte/half/word/unsigned based on which variant of multiply. + _ => {} + } + } + _ => panic!("Invalid opcode"), + } + } + + fn populate_memory(&self, cols: &mut CpuCols, event: CpuEvent) { + if let Some(memory_value) = event.memory_value { + cols.mem_val = memory_value.into(); + } + if let Some(addr) = event.addr { + cols.addr = addr.into(); + } + } + + fn populate_branch(&self, cols: &mut CpuCols, event: CpuEvent) { + if let Some(branch_condition) = event.branch_condition { + cols.branch_cond_val = (branch_condition as u32).into(); } - // TODO: make Into for Iter to Word and use that here. - cols.op_a_val = Word( - event.operands[0] - .to_le_bytes() - .iter() - .map(|v| F::from_canonical_u8(*v)) - .collect::>() - .try_into() - .unwrap(), - ); - cols.op_b_val = Word( - event.operands[1] - .to_le_bytes() - .iter() - .map(|v| F::from_canonical_u8(*v)) - .collect::>() - .try_into() - .unwrap(), - ); - cols.op_c_val = Word( - event.operands[2] - .to_le_bytes() - .iter() - .map(|v| F::from_canonical_u8(*v)) - .collect::>() - .try_into() - .unwrap(), - ); - row } } @@ -222,6 +335,9 @@ mod tests { c: 2, }, operands: [1, 2, 3], + addr: None, + memory_value: None, + branch_condition: None, }]; let chip = CpuChip:: { _phantom: Default::default(), diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index aa3d502d56..02a6046489 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -9,6 +9,7 @@ use std::{ collections::BTreeMap, fmt::{Display, Formatter}, + mem, }; use crate::{ @@ -448,11 +449,36 @@ impl Runtime { } fn emit_cpu(&mut self, clk: u32, pc: u32, instruction: Instruction, a: u32, b: u32, c: u32) { + let (addr, memory_value) = match instruction.opcode { + Opcode::LB | Opcode::LH | Opcode::LW | Opcode::LBU | Opcode::LHU => { + let addr = b.wrapping_add(c); + let memory_value = self.mr(addr); + (Some(addr), Some(memory_value)) + } + Opcode::SB | Opcode::SH | Opcode::SW => { + let addr = b.wrapping_add(c); + let memory_value = self.mr(addr); + (Some(addr), Some(memory_value)) + } + _ => (None, None), + }; + let branch_condition = match instruction.opcode { + Opcode::BEQ => Some(a == b), + Opcode::BNE => Some(a != b), + Opcode::BLT => Some((a as i32) < (b as i32)), + Opcode::BGE => Some((a as i32) >= (b as i32)), + Opcode::BLTU => Some(a < b), + Opcode::BGEU => Some(a >= b), + _ => None, + }; self.cpu_events.push(CpuEvent { clk: self.clk, pc: self.pc, instruction, operands: [a, b, c], + addr: addr, + memory_value, + branch_condition, }); } @@ -462,6 +488,7 @@ impl Runtime { let mut a: u32 = u32::MAX; let mut b: u32 = u32::MAX; let mut c: u32 = u32::MAX; + match instruction.opcode { // R-type instructions. Opcode::ADD => { @@ -723,13 +750,13 @@ impl Runtime { // Upper immediate instructions. Opcode::LUI => { let (rd, imm) = instruction.u_type(); - (b, c) = (imm, 0); // Note that we'll special-case this in the CPU table + (b, c) = (imm, 12); // Note that we'll special-case this in the CPU table a = b << 12; self.rw(rd, a); } Opcode::AUIPC => { let (rd, imm) = instruction.u_type(); - (b, c) = (imm, 0); // Note that we'll special-case this in the CPU table + (b, c) = (imm, imm << 12); // Note that we'll special-case this in the CPU table a = self.pc.wrapping_add(b << 12); self.rw(rd, a); }