Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pages in protocol #54

Closed
wants to merge 16 commits into from
6 changes: 6 additions & 0 deletions core/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ pub struct AirInteraction<E> {
///
/// All `AirBuilder` implementations automatically implement this trait.
pub trait CurtaAirBuilder: AirBuilder + MessageBuilder<AirInteraction<Self::Expr>> {
fn assert_word_zero<I: Into<Self::Expr>>(&mut self, value: Word<I>) {
for value in value.0 {
self.assert_zero(value);
}
}

fn assert_word_eq<I: Into<Self::Expr>>(&mut self, left: Word<I>, right: Word<I>) {
for (left, right) in left.0.into_iter().zip(right.0) {
self.assert_eq(left, right);
Expand Down
65 changes: 60 additions & 5 deletions core/src/memory/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub struct MemoryCols<T> {
pub is_read: Bool<T>,
/// The multiplicity of this memory access.
pub multiplicity: T,

/// The previous address of the table. Needed for the bus argument access of "less_than"
pub prev_addr: Word<T>,
/// A decoding of the clk to a 32-bit word.
Expand All @@ -53,6 +52,14 @@ pub struct MemoryCols<T> {
pub is_clk_lt: Bool<T>,
/// A flag to indicate whether the memory access consistency is checked.
pub is_checked: Bool<T>,
/// A flag to indicate whether this is the last operation for the current address.
pub is_last: Bool<T>,
/// A flag to indicate whether we have a new write event (a write with a non-zero clock cycle)
pub is_new_write: Bool<T>,
/// A flag to inidcate whether we mutated this address in the program execution.
pub is_changed: Bool<T>,
/// Oouput page multiplicity. We send a last event to the page if the data was mutated.
pub out_page_mult: T,
}

const fn make_col_map() -> MemoryCols<usize> {
Expand Down Expand Up @@ -111,9 +118,6 @@ impl<AB: CurtaAirBuilder> Air<AB> for MemoryChip {
// Assert that `clk_word` is a decoding of `clk`.
let clk_expected = reduce::<AB>(local.clk_word);
builder.assert_eq(clk_expected, local.clk);
// If the operation is a write, the multiplicity must be 1.
// TODO: Figure out if this constraint is necessary.
// builder.assert_zero(local.is_read.0 * (local.multiplicity - AB::F::one()));

// Lookup values validity checks
//
Expand Down Expand Up @@ -172,12 +176,63 @@ impl<AB: CurtaAirBuilder> Air<AB> for MemoryChip {
.when(next.is_clk_eq.0)
.assert_eq(next.clk, local.clk);

// An operation is the last one for an adress if the next address is different from the
// current one.
builder
.when_transition()
.when(local.multiplicity)
.assert_one(local.is_last.0 + next.is_addr_eq.0);
// In the last row, record `is_last` as `1` if the multiplicity is non-zero.
// builder
// .when_last_row()
// .when(local.multiplicity)
// .assert_eq(local.is_last.0, AB::F::one());

// Constrain the `is_new_write` flag.
//
// The `is_new_write` flag is set to `1` if the current operation is a write event with a
// a non-zero `clk` (so that's not a write from the initial memory state).
builder
.when(local.clk)
.when(local.is_read.0 - AB::F::one())
.assert_eq(local.is_new_write.0, AB::F::one());

// Constrain the `is_changed` flag.
//
// The `is_changed` flag is set to `1` if the `is_new_write` flag is set to `1` at any point
// in the memory operation for the current address. This can be constrained as follows:
// + `is_changed` is equal to `is_new_write` whenever the current address is new.
// + When the address is the same as the previous one, `is_changed` is equal to the
// OR of `is_changed` from the last row and `is_new_write`.
builder
.when(local.is_addr_eq.0 - AB::F::one())
.assert_eq(local.is_changed.0, local.is_new_write.0);
builder
.when_transition()
.when(next.multiplicity)
.when(next.is_addr_eq.0)
.assert_eq(
next.is_changed.0,
local.is_changed.0 + next.is_new_write.0 - local.is_changed.0 * next.is_new_write.0,
);

// Constrain the `out_page_mult` flag. This flag is set to the AND of `is_changed` and
// `is_last`, so that we send the last event of an address if the data was mutated. These
// hold while the event is not padding (i.e. the multiplicity is non-zero).
builder
.when(local.multiplicity)
.assert_eq(local.out_page_mult, local.is_last.0 * local.is_changed.0);
builder
.when(local.multiplicity - AB::F::one())
.assert_zero(local.out_page_mult);

// At every row, record the memory interaction.
builder.recieve_memory(
local.clk,
local.addr,
local.value,
local.is_read.0,
local.multiplicity,
local.multiplicity, //+ local.out_page_mult,
);
}
}
29 changes: 5 additions & 24 deletions core/src/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
pub mod air;
pub mod page;
pub mod state;
pub mod state_old;
pub mod trace;

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -46,34 +49,13 @@ mod tests {

use crate::lookup::InteractionBuilder;
use crate::memory::{MemOp, MemoryChip};

use crate::runtime::tests::simple_program;
use crate::runtime::Runtime;
use crate::utils::Chip;

use p3_commit::ExtensionMmcs;

use super::air::NUM_MEMORY_COLS;
use super::MemoryEvent;

#[test]
fn test_memory_generate_trace() {
let events = vec![
MemoryEvent {
clk: 0,
addr: 0,
op: MemOp::Write,
value: 0,
},
MemoryEvent {
clk: 1,
addr: 0,
op: MemOp::Read,
value: 0,
},
];
let trace: RowMajorMatrix<BabyBear> = MemoryChip::generate_trace(&events);
println!("{:?}", trace.values)
}

#[test]
fn test_memory_prove_babybear() {
Expand Down Expand Up @@ -120,10 +102,9 @@ mod tests {
let program = simple_program();
let mut runtime = Runtime::new(program);
runtime.run();
let events = runtime.memory_events;

let trace: RowMajorMatrix<BabyBear> = MemoryChip::generate_trace(&events);
let air = MemoryChip::new();
let trace: RowMajorMatrix<BabyBear> = air.generate_trace(&mut runtime);
let proof = prove::<MyConfig, _>(&config, &air, &mut challenger, trace);

let mut challenger = Challenger::new(perm);
Expand Down
85 changes: 85 additions & 0 deletions core/src/memory/page/air.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use core::borrow::Borrow;
use core::borrow::BorrowMut;
use core::mem::size_of;

use crate::air::{Bool, CurtaAirBuilder, Word};
use p3_air::Air;
use p3_air::BaseAir;
use p3_field::AbstractField;

use p3_field::Field;
use p3_matrix::MatrixRowSlices;
use valida_derive::AlignedBorrow;

use super::InputPage;
use super::OutputPage;

pub const NUM_PAGE_COLS: usize = size_of::<PageCols<u8>>();
pub const NUM_OUT_PAGE_COLS: usize = size_of::<OutputPageCols<u8>>();

#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct PageCols<T> {
/// The address of the memory access.
pub addr: Word<T>,
/// The value being read from or written to memory.
pub value: Word<T>,
}

// #[derive(Debug, Clone, AlignedBorrow)]
// #[repr(C)]
// pub struct InputPageCols<T> {

// }

#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct OutputPageCols<T> {
/// The clock cycle value for this memory access.
pub clk: T,
/// Whether the memory was being read from or written to.
pub is_read: Bool<T>,
}

impl<F: Field> BaseAir<F> for InputPage {
fn width(&self) -> usize {
NUM_PAGE_COLS
}
}

impl<AB: CurtaAirBuilder> Air<AB> for InputPage {
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &PageCols<AB::Var> = main.row_slice(0).borrow();

builder.send_memory(
AB::F::zero(),
local.addr,
local.value,
AB::F::zero(),
AB::F::one(),
)
}
}

impl<F: Field> BaseAir<F> for OutputPage {
fn width(&self) -> usize {
NUM_PAGE_COLS + NUM_OUT_PAGE_COLS
}
}

impl<AB: CurtaAirBuilder> Air<AB> for OutputPage {
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &PageCols<AB::Var> = main.row_slice(0).borrow();
let out: &OutputPageCols<AB::Var> = main.row_slice(NUM_PAGE_COLS).borrow();

builder.send_memory(
out.clk,
local.addr,
local.value,
out.is_read.0,
AB::F::one(),
)
}
}
35 changes: 35 additions & 0 deletions core/src/memory/page/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
mod air;
mod trace;

pub const PAGE_DEGREE: usize = 10;
pub const PAGE_SIZE: usize = 1 << PAGE_DEGREE;

#[derive(Debug, Clone, Copy)]
pub struct InputPage {
page_id: u16,
}

#[derive(Debug, Clone, Copy)]
pub struct OutputPage {
page_id: u16,
}

impl InputPage {
pub fn new(page_id: u16) -> Self {
Self { page_id }
}

pub fn page_id(&self) -> u16 {
self.page_id
}
}

impl OutputPage {
pub fn new(page_id: u16) -> Self {
Self { page_id }
}

pub fn page_id(&self) -> u16 {
self.page_id
}
}
80 changes: 80 additions & 0 deletions core/src/memory/page/trace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::borrow::BorrowMut;

use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use crate::{
air::{Bool, Word},
memory::{MemOp, MemoryEvent},
};

use super::{
air::{OutputPageCols, PageCols, NUM_OUT_PAGE_COLS, NUM_PAGE_COLS},
InputPage, OutputPage,
};

pub struct OutputPageTrace<T> {
pub(crate) page: RowMajorMatrix<T>,
pub(crate) data: RowMajorMatrix<T>,
}

impl InputPage {
pub(crate) fn generate_trace<F: Field>(&self, in_events: &[MemoryEvent]) -> RowMajorMatrix<F> {
let rows = in_events
.par_iter()
.flat_map(|event| {
let mut row = [F::zero(); NUM_PAGE_COLS];

let cols: &mut PageCols<F> = row.as_mut_slice().borrow_mut();

cols.addr = Word::from(event.addr);
cols.value = Word::from(event.value);

row
})
.collect::<Vec<_>>();

RowMajorMatrix::new(rows, NUM_PAGE_COLS)
}
}

impl OutputPage {
pub(crate) fn generate_trace<F: Field>(
&self,
out_events: &[MemoryEvent],
) -> OutputPageTrace<F> {
let page_rows = out_events
.par_iter()
.flat_map(|event| {
let mut row = [F::zero(); NUM_PAGE_COLS];

let cols: &mut PageCols<F> = row.as_mut_slice().borrow_mut();

cols.addr = Word::from(event.addr);
cols.value = Word::from(event.value);

row
})
.collect::<Vec<_>>();

let data_rows = out_events
.par_iter()
.flat_map(|event| {
let mut row = [F::zero(); NUM_OUT_PAGE_COLS];

let cols: &mut OutputPageCols<F> = row.as_mut_slice().borrow_mut();

cols.clk = F::from_canonical_u32(event.clk);
cols.is_read = Bool::from(event.op == MemOp::Read);

row
})
.collect::<Vec<_>>();

OutputPageTrace {
page: RowMajorMatrix::new(page_rows, NUM_PAGE_COLS),
data: RowMajorMatrix::new(data_rows, NUM_OUT_PAGE_COLS),
}
}
}
1 change: 1 addition & 0 deletions core/src/memory/state/merkle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading