Skip to content

Commit

Permalink
Add context pruning (#112)
Browse files Browse the repository at this point in the history
* Implement context pruning

* Fix clippy

* Apply comments

* Remove dummy CTL table

* Apply comments

* Apply comments

* Apply comments
  • Loading branch information
hratoanina authored Apr 2, 2024
1 parent d36d19b commit 867a172
Show file tree
Hide file tree
Showing 12 changed files with 340 additions and 73 deletions.
18 changes: 15 additions & 3 deletions evm_arithmetization/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ use starky::stark::Stark;
use crate::arithmetic::arithmetic_stark;
use crate::arithmetic::arithmetic_stark::ArithmeticStark;
use crate::byte_packing::byte_packing_stark::{self, BytePackingStark};
use crate::cpu::cpu_stark;
use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::cpu_stark::{self, ctl_context_pruning_looked};
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::keccak::keccak_stark;
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_sponge::columns::KECCAK_RATE_BYTES;
use crate::keccak_sponge::keccak_sponge_stark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark;
use crate::memory::memory_stark::MemoryStark;
use crate::memory::memory_stark::{self, ctl_context_pruning_looking};
use crate::memory_continuation::memory_continuation_stark::{self, MemoryContinuationStark};
use crate::{logic, memory_continuation};

Expand Down Expand Up @@ -134,14 +134,18 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
ctl_memory(),
ctl_mem_before(),
ctl_mem_after(),
ctl_context_pruning(),
]
}

/// `CrossTableLookup` for `ArithmeticStark`, to connect it with the `Cpu`
/// module.
fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![cpu_stark::ctl_arithmetic_base_rows()],
vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_context_pruning(),
],
arithmetic_stark::ctl_arithmetic_rows(),
)
}
Expand Down Expand Up @@ -327,6 +331,14 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(all_lookers, memory_looked)
}

/// `CrossTableLookup` for `Cpu` to propagate stale contexts to `Memory`.
fn ctl_context_pruning<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![ctl_context_pruning_looking()],
ctl_context_pruning_looked(),
)
}

/// `CrossTableLookup` for `MemBefore` table to connect it with the `Memory`
/// module.
fn ctl_mem_before<F: Field>() -> CrossTableLookup<F> {
Expand Down
21 changes: 21 additions & 0 deletions evm_arithmetization/src/cpu/columns/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
jumps: CpuJumpsView<T>,
shift: CpuShiftView<T>,
stack: CpuStackView<T>,
context_pruning: CpuContextPruningView<T>,
}

impl<T: Copy> CpuGeneralColumnsView<T> {
Expand Down Expand Up @@ -75,6 +76,18 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
pub(crate) fn stack_mut(&mut self) -> &mut CpuStackView<T> {
unsafe { &mut self.stack }
}

/// View of the column for context pruning.
/// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn context_pruning(&self) -> &CpuContextPruningView<T> {
unsafe { &self.context_pruning }
}

/// Mutable view of the column for context pruning.
/// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn context_pruning_mut(&mut self) -> &mut CpuContextPruningView<T> {
unsafe { &mut self.context_pruning }
}
}

impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
Expand Down Expand Up @@ -142,6 +155,14 @@ pub(crate) struct CpuShiftView<T: Copy> {
pub(crate) high_limb_sum_inv: T,
}

/// View of the first `CpuGeneralColumns` storing a flag for context pruning.
#[derive(Copy, Clone)]
pub(crate) struct CpuContextPruningView<T: Copy> {
/// The flag is 1 if the OP flag `context_op` is set, the operation is
/// `SET_CONTEXT` and `new_ctx < old_ctx`, and 0 otherwise.
pub(crate) pruning_flag: T,
}

/// View of the last four `CpuGeneralColumns` storing stack-related variables.
/// The first three are used for conditionally enabling and disabling channels
/// when reading the next `stack_top`, and the fourth one is used to check for
Expand Down
47 changes: 47 additions & 0 deletions evm_arithmetization/src/cpu/cpu_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,53 @@ pub(crate) fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
)
}

/// Returns the `TableWithColumns` for the context pruning inequality. It's a LT
/// operation on `new_ctx` and `old_ctx`.
pub(crate) fn ctl_arithmetic_context_pruning<F: Field>() -> TableWithColumns<F> {
// Opcode is LT.
let mut columns = vec![Column::constant(F::from_canonical_usize(0x10))];
// `input0` = `new_ctx`. Context is shifted; higher limbs are all zero.
columns.push(Column::single(COL_MAP.mem_channels[0].value[2]));
columns.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS - 1));
// `input1` = `old_ctx`. Higher limbs are all zero.
columns.push(Column::single(COL_MAP.context));
columns.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS - 1));
// `input2` doesn't matter.
columns.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS));
// `res` is the first general column. Higher limbs are all zero.
columns.push(Column::single(
COL_MAP.general.context_pruning().pruning_flag,
));
columns.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS - 1));

TableWithColumns::new(
*Table::Cpu,
columns,
Some(Filter::new(
vec![(
Column::single(COL_MAP.op.context_op),
Column::single(COL_MAP.opcode_bits[0]),
)],
vec![],
)),
)
}

/// Returns a column containing stale contexts.
pub(crate) fn ctl_context_pruning_looked<F: Field>() -> TableWithColumns<F> {
TableWithColumns::new(
*Table::Cpu,
vec![Column::single(COL_MAP.context)],
Some(Filter::new(
vec![(
Column::single(COL_MAP.op.context_op),
Column::single(COL_MAP.general.context_pruning().pruning_flag),
)],
vec![],
)),
)
}

/// Creates the vector of `Columns` corresponding to the contents of General
/// Purpose channels when calling byte packing. We use `ctl_data_keccak_sponge`
/// because the `Columns` are the same as the ones computed for
Expand Down
10 changes: 7 additions & 3 deletions evm_arithmetization/src/cpu/kernel/asm/account_code.asm
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,20 @@ global sys_extcodesize:
SWAP1
// stack: address, kexit_info
%extcodesize
// stack: code_size, codesize_ctx, kexit_info
SWAP1
// stack: codesize_ctx, code_size, kexit_info
%prune_context
// stack: code_size, kexit_info
SWAP1
EXIT_KERNEL

// Pre stack: address, retdest
// Post stack: code_size, codesize_ctx
global extcodesize:
// stack: address, retdest
%next_context_id
// stack: codesize_ctx, address, retdest
SWAP1
// stack: address, codesize_ctx, retdest
%stack(codesize_ctx, address, retdest) -> (address, codesize_ctx, retdest, codesize_ctx)
%jump(load_code)

// Loads the code at `address` into memory, in the code segment of the given context, starting at offset 0.
Expand Down
69 changes: 66 additions & 3 deletions evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ calldataload_large_offset:
codecopy_within_bounds:
// stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size
POP
// stack: segment, src_ctx, kexit_info, dest_offset, offset, size
GET_CONTEXT
%stack (context, segment, src_ctx, kexit_info, dest_offset, offset, size) ->
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, size, codecopy_after, src_ctx, kexit_info)
%build_address
SWAP3 %build_address
// stack: DST, SRC, size, codecopy_after, src_ctx, kexit_info
%jump(memcpy_bytes)

wcopy_within_bounds:
// stack: segment, src_ctx, kexit_info, dest_offset, offset, size
GET_CONTEXT
Expand All @@ -131,7 +140,15 @@ wcopy_empty:

codecopy_large_offset:
// stack: total_size, src_ctx, kexit_info, dest_offset, offset, size
%pop2
POP
// offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros.
// stack: src_ctx, kexit_info, dest_offset, offset, size
GET_CONTEXT
%stack (context, src_ctx, kexit_info, dest_offset, offset, size) ->
(context, @SEGMENT_MAIN_MEMORY, dest_offset, size, codecopy_after, src_ctx, kexit_info)
%build_address
%jump(memset)

wcopy_large_offset:
// offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros.
// stack: kexit_info, dest_offset, offset, size
Expand All @@ -141,6 +158,24 @@ wcopy_large_offset:
%build_address
%jump(memset)

codecopy_after:
// stack: src_ctx, kexit_info
DUP1 GET_CONTEXT
// stack: ctx, src_ctx, src_ctx, kexit_info
// If ctx == src_ctx, it's a CODECOPY, and we don't need to prune the context.
EQ
// stack: ctx == src_ctx, src_ctx, kexit_info
%jumpi(codecopy_no_prune)
// stack: src_ctx, kexit_info
%prune_context
// stack: kexit_info
EXIT_KERNEL

codecopy_no_prune:
// stack: src_ctx, kexit_info
POP
EXIT_KERNEL

wcopy_after:
// stack: kexit_info
EXIT_KERNEL
Expand Down Expand Up @@ -248,9 +283,37 @@ extcodecopy_contd:

GET_CONTEXT
%stack (context, new_dest_offset, copy_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size) ->
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size)
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size)
%build_address
SWAP3 %build_address
// stack: DST, SRC, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size
// stack: DST, SRC, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size
%jump(memcpy_bytes)
%endmacro

// Adds stale_ctx to the list of stale contexts. You need to return to a previous, older context with
// a SET_CONTEXT instruction. By assumption, stale_ctx is greater than the current context.
%macro prune_context
// stack: stale_ctx
GET_CONTEXT
// stack: curr_ctx, stale_ctx
// When we go to stale_ctx, we want its stack to contain curr_ctx so that we can immediately
// call SET_CONTEXT. For that, we need a stack length of 1, and store curr_ctx in Segment::Stack[0].
PUSH @SEGMENT_STACK
DUP3 ADD
// stack: stale_ctx_stack_addr, curr_ctx, stale_ctx
DUP2
// stack: curr_ctx, stale_ctx_stack_addr, curr_ctx, stale_ctx
MSTORE_GENERAL
// stack: curr_ctx, stale_ctx
PUSH @CTX_METADATA_STACK_SIZE
DUP3 ADD
// stack: stale_ctx_stack_size_addr, curr_ctx, stale_ctx
PUSH 1
MSTORE_GENERAL
// stack: curr_ctx, stale_ctx
POP
SET_CONTEXT
// We're now in stale_ctx, with stack: curr_ctx
SET_CONTEXT
// We're now in curr_ctx, with an empty stack.
%endmacro
5 changes: 4 additions & 1 deletion evm_arithmetization/src/cpu/kernel/tests/account_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ fn test_extcodesize() -> Result<()> {
HashMap::from([(keccak(&code), code.clone())]);
interpreter.run()?;

assert_eq!(interpreter.stack(), vec![code.len().into()]);
assert_eq!(
interpreter.stack(),
vec![U256::one() << CONTEXT_SCALING_FACTOR, code.len().into()]
);

Ok(())
}
Expand Down
11 changes: 8 additions & 3 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,14 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
let (tables, final_values) = timed!(
timing,
"convert trace data to tables",
state
.traces
.into_tables(all_stark, &memory_before, trace_lengths, config, timing)
state.traces.into_tables(
all_stark,
&memory_before,
state.stale_contexts,
trace_lengths,
config,
timing
)
);
Ok((tables, public_values, final_values))
}
Expand Down
18 changes: 16 additions & 2 deletions evm_arithmetization/src/generation/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::memory::segments::Segment;
use crate::util::u256_to_usize;
use crate::witness::errors::ProgramError;
use crate::witness::memory::MemoryChannel::GeneralPurpose;
use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryState};
use crate::witness::memory::{MemoryAddress, MemoryContextState, MemoryOp, MemoryState};
use crate::witness::memory::{MemoryOpKind, MemorySegmentState};
use crate::witness::operation::{generate_exception, Operation};
use crate::witness::state::RegistersState;
Expand Down Expand Up @@ -212,7 +212,15 @@ pub(crate) trait State<F: Field> {
if !running {
assert_eq!(self.get_clock() - final_clock, NUM_EXTRA_CYCLES_AFTER - 1);
}
let final_mem = self.get_full_memory();
let final_mem = if let Some(mut mem) = self.get_full_memory() {
// Clear memory we will not use again.
for &ctx in &self.get_generation_state().stale_contexts {
mem.contexts[ctx] = MemoryContextState::default();
}
Some(mem)
} else {
None
};
#[cfg(not(test))]
self.log_info(format!("CPU halted after {} cycles", self.get_clock()));
return Ok((final_registers, final_mem));
Expand Down Expand Up @@ -326,6 +334,10 @@ pub struct GenerationState<F: Field> {
pub(crate) memory: MemoryState,
pub(crate) traces: Traces<F>,

/// Memory used by stale contexts can be pruned so proving segments can be
/// smaller.
pub(crate) stale_contexts: Vec<usize>,

/// Prover inputs containing RLP data, in reverse order so that the next
/// input can be obtained via `pop()`.
pub(crate) rlp_prover_inputs: Vec<U256>,
Expand Down Expand Up @@ -375,6 +387,7 @@ impl<F: Field> GenerationState<F> {
registers: Default::default(),
memory: MemoryState::new(kernel_code),
traces: Traces::default(),
stale_contexts: Vec::new(),
rlp_prover_inputs,
withdrawal_prover_inputs,
state_key_to_address: HashMap::new(),
Expand Down Expand Up @@ -462,6 +475,7 @@ impl<F: Field> GenerationState<F> {
registers: self.registers,
memory: self.memory.clone(),
traces: Traces::default(),
stale_contexts: Vec::new(),
rlp_prover_inputs: self.rlp_prover_inputs.clone(),
state_key_to_address: self.state_key_to_address.clone(),
bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(),
Expand Down
19 changes: 18 additions & 1 deletion evm_arithmetization/src/memory/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,25 @@ pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1;
// Contains `next_segment * addr_changed * next_is_read`.
pub(crate) const INITIALIZE_AUX: usize = VIRTUAL_FIRST_CHANGE + 1;

// Contains `row_index` if and only if context `row_index` is stale,
// and zero if not.
pub(crate) const STALE_CONTEXTS: usize = INITIALIZE_AUX + 1;

// Pseudo-inverse of `STALE_CONTEXTS`. Used to ascertain it's nonzero.
pub(crate) const STALE_CONTEXTS_INV: usize = STALE_CONTEXTS + 1;

// Used for the context pruning lookup.
pub(crate) const STALE_CONTEXTS_FREQUENCIES: usize = STALE_CONTEXTS_INV + 1;

// Flag indicating whether the row should be pruned, i.e. whether its
// `ADDR_CONTEXT` is in `STALE_CONTEXTS`.
pub(crate) const IS_STALE: usize = STALE_CONTEXTS_FREQUENCIES + 1;

// Filter for the `MemAfter` CTL.
pub(crate) const MEM_AFTER_FILTER: usize = IS_STALE + 1;

// We use a range check to enforce the ordering.
pub(crate) const RANGE_CHECK: usize = INITIALIZE_AUX + 1;
pub(crate) const RANGE_CHECK: usize = MEM_AFTER_FILTER + 1;
/// The counter column (used for the range check) starts from 0 and increments.
pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
/// The frequencies column used in logUp.
Expand Down
Loading

0 comments on commit 867a172

Please sign in to comment.