Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context pruning #112

Merged
merged 8 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions evm_arithmetization/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ use starky::stark::Stark;
use crate::arithmetic::arithmetic_stark;
use crate::arithmetic::arithmetic_stark::ArithmeticStark;
use crate::byte_packing::byte_packing_stark::{self, BytePackingStark};
use crate::cpu::cpu_stark;
use crate::cpu::cpu_stark::CpuStark;
use crate::cpu::cpu_stark::{self, ctl_context_pruning_looked};
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::keccak::keccak_stark;
use crate::keccak::keccak_stark::KeccakStark;
use crate::keccak_sponge::columns::KECCAK_RATE_BYTES;
use crate::keccak_sponge::keccak_sponge_stark;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark;
use crate::logic::LogicStark;
use crate::memory::memory_stark;
use crate::memory::memory_stark::MemoryStark;
use crate::memory::memory_stark::{self, ctl_context_pruning_looking};
use crate::memory_continuation::memory_continuation_stark::{self, MemoryContinuationStark};
use crate::{logic, memory_continuation};

Expand Down Expand Up @@ -134,14 +134,18 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
ctl_memory(),
ctl_mem_before(),
ctl_mem_after(),
ctl_context_pruning(),
]
}

/// `CrossTableLookup` for `ArithmeticStark`, to connect it with the `Cpu`
/// module.
fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![cpu_stark::ctl_arithmetic_base_rows()],
vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_context_pruning(),
],
arithmetic_stark::ctl_arithmetic_rows(),
)
}
Expand Down Expand Up @@ -327,6 +331,14 @@ fn ctl_memory<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(all_lookers, memory_looked)
}

/// Builds the `CrossTableLookup` dedicated to context pruning.
///
/// The single looking table is produced by
/// `memory_stark::ctl_context_pruning_looking`, and the looked table by
/// `cpu_stark::ctl_context_pruning_looked`.
fn ctl_context_pruning<F: Field>() -> CrossTableLookup<F> {
    let looking_tables = vec![ctl_context_pruning_looking()];
    let looked_table = ctl_context_pruning_looked();

    CrossTableLookup::new(looking_tables, looked_table)
}

/// `CrossTableLookup` for `MemBefore` table to connect it with the `Memory`
/// module.
fn ctl_mem_before<F: Field>() -> CrossTableLookup<F> {
Expand Down
21 changes: 21 additions & 0 deletions evm_arithmetization/src/cpu/columns/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
jumps: CpuJumpsView<T>,
shift: CpuShiftView<T>,
stack: CpuStackView<T>,
context_pruning: CpuContextPruningView<T>,
}

impl<T: Copy> CpuGeneralColumnsView<T> {
Expand Down Expand Up @@ -75,6 +76,18 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
pub(crate) fn stack_mut(&mut self) -> &mut CpuStackView<T> {
unsafe { &mut self.stack }
}

/// View of the column for context pruning.
///
/// Returns a shared reference to the general columns interpreted as a
/// `CpuContextPruningView`.
/// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn context_pruning(&self) -> &CpuContextPruningView<T> {
// SAFETY: reading this union field is sound because each view is a valid
// interpretation of the underlying array.
unsafe { &self.context_pruning }
}

/// Mutable view of the column for context pruning.
///
/// Returns an exclusive reference to the general columns interpreted as a
/// `CpuContextPruningView`.
/// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn context_pruning_mut(&mut self) -> &mut CpuContextPruningView<T> {
// SAFETY: borrowing this union field is sound because each view is a valid
// interpretation of the underlying array.
unsafe { &mut self.context_pruning }
}
}

impl<T: Copy + PartialEq> PartialEq<Self> for CpuGeneralColumnsView<T> {
Expand Down Expand Up @@ -142,6 +155,14 @@ pub(crate) struct CpuShiftView<T: Copy> {
pub(crate) high_limb_sum_inv: T,
}

/// View of the first `CpuGeneralColumns` storing a flag for context pruning.
///
/// This view exposes a single column, `pruning_flag`.
#[derive(Copy, Clone)]
pub(crate) struct CpuContextPruningView<T: Copy> {
/// The flag is 1 if the OP flag `context_op` is set, the operation is
/// `SET_CONTEXT` and `new_ctx < old_ctx`, and 0 otherwise.
pub(crate) pruning_flag: T,
}

/// View of the last four `CpuGeneralColumns` storing stack-related variables.
/// The first three are used for conditionally enabling and disabling channels
/// when reading the next `stack_top`, and the fourth one is used to check for
Expand Down
47 changes: 47 additions & 0 deletions evm_arithmetization/src/cpu/cpu_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,53 @@ pub(crate) fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
)
}

/// Returns the `TableWithColumns` the CPU contributes to the arithmetic lookup
/// for the context pruning inequality, i.e. an LT operation comparing
/// `new_ctx` against `old_ctx`.
///
/// Column layout, in order: the opcode constant, then `input0`, `input1`,
/// `input2` and `res`, each spanning `VALUE_LIMBS` columns. The filter is
/// built from the pair (`op.context_op`, `opcode_bits[0]`).
pub(crate) fn ctl_arithmetic_context_pruning<F: Field>() -> TableWithColumns<F> {
    // Padding for limbs that carry no data.
    let zero_limbs = |count: usize| repeat(Column::constant(F::ZERO)).take(count);

    // Opcode is LT.
    let lt_opcode = Column::constant(F::from_canonical_usize(0x10));
    // Context is shifted when read from the first memory channel.
    let new_ctx = Column::single(COL_MAP.mem_channels[0].value[2]);
    let old_ctx = Column::single(COL_MAP.context);
    // The comparison result lives in the first general column.
    let pruning_flag = Column::single(COL_MAP.general.context_pruning().pruning_flag);

    let mut columns = vec![lt_opcode];

    // `input0` = `new_ctx`; higher limbs are all zero.
    columns.push(new_ctx);
    columns.extend(zero_limbs(VALUE_LIMBS - 1));

    // `input1` = `old_ctx`; higher limbs are all zero.
    columns.push(old_ctx);
    columns.extend(zero_limbs(VALUE_LIMBS - 1));

    // `input2` doesn't matter, so it is entirely zero.
    columns.extend(zero_limbs(VALUE_LIMBS));

    // `res` = pruning flag; higher limbs are all zero.
    columns.push(pruning_flag);
    columns.extend(zero_limbs(VALUE_LIMBS - 1));

    let filter = Filter::new(
        vec![(
            Column::single(COL_MAP.op.context_op),
            Column::single(COL_MAP.opcode_bits[0]),
        )],
        vec![],
    );

    TableWithColumns::new(*Table::Cpu, columns, Some(filter))
}

/// Returns the CPU table exposing the pruned contexts.
///
/// The table has a single column, `COL_MAP.context`, and its filter is built
/// from the pair (`op.context_op`, `pruning_flag`).
pub(crate) fn ctl_context_pruning_looked<F: Field>() -> TableWithColumns<F> {
    let context = Column::single(COL_MAP.context);
    let is_context_op = Column::single(COL_MAP.op.context_op);
    let pruning_flag = Column::single(COL_MAP.general.context_pruning().pruning_flag);

    let filter = Filter::new(vec![(is_context_op, pruning_flag)], vec![]);

    TableWithColumns::new(*Table::Cpu, vec![context], Some(filter))
}

/// Creates the vector of `Columns` corresponding to the contents of General
/// Purpose channels when calling byte packing. We use `ctl_data_keccak_sponge`
/// because the `Columns` are the same as the ones computed for
Expand Down
10 changes: 7 additions & 3 deletions evm_arithmetization/src/cpu/kernel/asm/account_code.asm
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,20 @@ global sys_extcodesize:
SWAP1
// stack: address, kexit_info
%extcodesize
// stack: code_size, codesize_ctx, kexit_info
SWAP1
// stack: codesize_ctx, code_size, kexit_info
%prune_context
// stack: code_size, kexit_info
SWAP1
EXIT_KERNEL

// Pre stack: address, retdest
// Post stack: code_size, codesize_ctx
global extcodesize:
// stack: address, retdest
%next_context_id
// stack: codesize_ctx, address, retdest
SWAP1
// stack: address, codesize_ctx, retdest
%stack(codesize_ctx, address, retdest) -> (address, codesize_ctx, retdest, codesize_ctx)
%jump(load_code)

// Loads the code at `address` into memory, in the code segment of the given context, starting at offset 0.
Expand Down
69 changes: 66 additions & 3 deletions evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ calldataload_large_offset:
codecopy_within_bounds:
// stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size
POP
// stack: segment, src_ctx, kexit_info, dest_offset, offset, size
Copy link
Contributor

@4l0n50 4l0n50 Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to add these changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

codecopy_within_bounds calls codecopy_after instead of wcopy_after. I agree that's more or less code duplication, but I didn't see an easy way to do it more cleanly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My problem is that I can't see where you are changing the context here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, the problem with codecopy etc. is that they do not change contexts, but write code in another context. And the idea here is to still be able to prune that extra context when needed by setting the context in codecopy_after.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh, and why will src_ctx never be accessed again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of extcodecopy and extcodesize, we generate src_ctx with next_context_id which increments @GLOBAL_METADATA_LARGEST_CONTEXT, meaning that any other created context will have a different id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh ok, I think I see now. It's because you just need the new_ctx for writing the code there (and checking the hash etc.), but after you copy the code you can forget about it?

GET_CONTEXT
%stack (context, segment, src_ctx, kexit_info, dest_offset, offset, size) ->
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, size, codecopy_after, src_ctx, kexit_info)
%build_address
SWAP3 %build_address
// stack: DST, SRC, size, codecopy_after, src_ctx, kexit_info
%jump(memcpy_bytes)

wcopy_within_bounds:
// stack: segment, src_ctx, kexit_info, dest_offset, offset, size
GET_CONTEXT
Expand All @@ -131,7 +140,15 @@ wcopy_empty:

codecopy_large_offset:
// stack: total_size, src_ctx, kexit_info, dest_offset, offset, size
%pop2
POP
// offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros.
// stack: src_ctx, kexit_info, dest_offset, offset, size
GET_CONTEXT
%stack (context, src_ctx, kexit_info, dest_offset, offset, size) ->
(context, @SEGMENT_MAIN_MEMORY, dest_offset, size, codecopy_after, src_ctx, kexit_info)
%build_address
%jump(memset)

wcopy_large_offset:
// offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros.
// stack: kexit_info, dest_offset, offset, size
Expand All @@ -141,6 +158,24 @@ wcopy_large_offset:
%build_address
%jump(memset)

// Continuation for the `codecopy_*` paths: this label is placed on the stack (presumably as the
// return destination) before jumping to `memcpy_bytes` / `memset`, with `src_ctx` and
// `kexit_info` right below it.
// Prunes `src_ctx` unless it is the current context, then exits the kernel.
// Pre stack: src_ctx, kexit_info
codecopy_after:
// stack: src_ctx, kexit_info
DUP1 GET_CONTEXT
// stack: ctx, src_ctx, src_ctx, kexit_info
// If ctx == src_ctx, it's a CODECOPY, and we don't need to prune the context.
EQ
// stack: ctx == src_ctx, src_ctx, kexit_info
%jumpi(codecopy_no_prune)
// stack: src_ctx, kexit_info
%prune_context
// stack: kexit_info
EXIT_KERNEL

// Reached from `codecopy_after` when `src_ctx` equals the current context: nothing to prune,
// so `src_ctx` is dropped and the kernel is exited.
codecopy_no_prune:
// stack: src_ctx, kexit_info
POP
EXIT_KERNEL

wcopy_after:
// stack: kexit_info
EXIT_KERNEL
Expand Down Expand Up @@ -248,9 +283,37 @@ extcodecopy_contd:

GET_CONTEXT
%stack (context, new_dest_offset, copy_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size) ->
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size)
(src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size)
%build_address
SWAP3 %build_address
// stack: DST, SRC, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size
// stack: DST, SRC, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size
%jump(memcpy_bytes)
%endmacro

// Adds pruned_ctx to the list of pruned contexts. You need to return to a previous, older context with
// a SET_CONTEXT instruction. By assumption, pruned_ctx is greater than the current context.
%macro prune_context
// stack: pruned_ctx
GET_CONTEXT
// stack: curr_ctx, pruned_ctx
// When we go to pruned_ctx, we want its stack to contain curr_ctx so that we can immediately
// call SET_CONTEXT. For that, we need a stack length of 1, and store curr_ctx in Segment::Stack[0].
4l0n50 marked this conversation as resolved.
Show resolved Hide resolved
PUSH @SEGMENT_STACK
DUP3 ADD
// stack: pruned_ctx_stack_addr, curr_ctx, pruned_ctx
DUP2
// stack: curr_ctx, pruned_ctx_stack_addr, curr_ctx, pruned_ctx
MSTORE_GENERAL
// stack: curr_ctx, pruned_ctx
PUSH @CTX_METADATA_STACK_SIZE
DUP3 ADD
// stack: pruned_ctx_stack_size_addr, curr_ctx, pruned_ctx
PUSH 1
MSTORE_GENERAL
// stack: curr_ctx, pruned_ctx
POP
SET_CONTEXT
// We're now in pruned_ctx, with stack: curr_ctx
SET_CONTEXT
// We're now in curr_ctx, with an empty stack.
%endmacro
5 changes: 4 additions & 1 deletion evm_arithmetization/src/cpu/kernel/tests/account_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ fn test_extcodesize() -> Result<()> {
HashMap::from([(keccak(&code), code.clone())]);
interpreter.run()?;

assert_eq!(interpreter.stack(), vec![code.len().into()]);
assert_eq!(
interpreter.stack(),
vec![U256::one() << CONTEXT_SCALING_FACTOR, code.len().into()]
);
Comment on lines +181 to +184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DQ: why do we expect this to be the content of the stack here? And why only for EXTCODESIZE but not for other tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's because extcodesize (not sys_extcodesize) now returns code_size, codesize_ctx. I added the returned stack in a comment in the ASM.


Ok(())
}
Expand Down
11 changes: 8 additions & 3 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,14 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
let (tables, final_values) = timed!(
timing,
"convert trace data to tables",
state
.traces
.into_tables(all_stark, &memory_before, trace_lengths, config, timing)
state.traces.into_tables(
all_stark,
&memory_before,
state.pruned_contexts,
trace_lengths,
config,
timing
)
);
Ok((tables, public_values, final_values))
}
Expand Down
18 changes: 16 additions & 2 deletions evm_arithmetization/src/generation/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::memory::segments::Segment;
use crate::util::u256_to_usize;
use crate::witness::errors::ProgramError;
use crate::witness::memory::MemoryChannel::GeneralPurpose;
use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryState};
use crate::witness::memory::{MemoryAddress, MemoryContextState, MemoryOp, MemoryState};
use crate::witness::memory::{MemoryOpKind, MemorySegmentState};
use crate::witness::operation::{generate_exception, Operation};
use crate::witness::state::RegistersState;
Expand Down Expand Up @@ -212,7 +212,15 @@ pub(crate) trait State<F: Field> {
if !running {
assert_eq!(self.get_clock() - final_clock, NUM_EXTRA_CYCLES_AFTER - 1);
}
let final_mem = self.get_full_memory();
let final_mem = if let Some(mut mem) = self.get_full_memory() {
// Clear memory we will not use again.
for &ctx in &self.get_generation_state().pruned_contexts {
hratoanina marked this conversation as resolved.
Show resolved Hide resolved
mem.contexts[ctx] = MemoryContextState::default();
}
Some(mem)
} else {
None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education, in which case is get_full_memory None? Does that mean that GenerationState uses the default impl and returns None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you're generating the proof (with the full trace), you don't need this memory. The Memory trace will be generated from the list of MemOps instead.

};
#[cfg(not(test))]
self.log_info(format!("CPU halted after {} cycles", self.get_clock()));
return Ok((final_registers, final_mem));
Expand Down Expand Up @@ -326,6 +334,10 @@ pub struct GenerationState<F: Field> {
pub(crate) memory: MemoryState,
pub(crate) traces: Traces<F>,

/// Memory used by stale contexts can be pruned so proving segments can be
/// smaller.
pub(crate) pruned_contexts: Vec<usize>,
hratoanina marked this conversation as resolved.
Show resolved Hide resolved

/// Prover inputs containing RLP data, in reverse order so that the next
/// input can be obtained via `pop()`.
pub(crate) rlp_prover_inputs: Vec<U256>,
Expand Down Expand Up @@ -375,6 +387,7 @@ impl<F: Field> GenerationState<F> {
registers: Default::default(),
memory: MemoryState::new(kernel_code),
traces: Traces::default(),
pruned_contexts: Vec::new(),
rlp_prover_inputs,
withdrawal_prover_inputs,
state_key_to_address: HashMap::new(),
Expand Down Expand Up @@ -462,6 +475,7 @@ impl<F: Field> GenerationState<F> {
registers: self.registers,
memory: self.memory.clone(),
traces: Traces::default(),
pruned_contexts: Vec::new(),
rlp_prover_inputs: self.rlp_prover_inputs.clone(),
state_key_to_address: self.state_key_to_address.clone(),
bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(),
Expand Down
19 changes: 18 additions & 1 deletion evm_arithmetization/src/memory/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,25 @@ pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1;
// Contains `next_segment * addr_changed * next_is_read`.
pub(crate) const INITIALIZE_AUX: usize = VIRTUAL_FIRST_CHANGE + 1;

/// Contains `row_index` if and only if context `row_index` must be pruned,
/// and zero if not.
pub(crate) const PRUNED_CONTEXTS: usize = INITIALIZE_AUX + 1;

/// Pseudo-inverse of `PRUNED_CONTEXTS`. Used to ascertain it's nonzero.
pub(crate) const PRUNED_CONTEXTS_INV: usize = PRUNED_CONTEXTS + 1;

/// Used for the context pruning lookup.
pub(crate) const PRUNED_CONTEXTS_FREQUENCIES: usize = PRUNED_CONTEXTS_INV + 1;

/// Flag indicating whether the row should be pruned, i.e. whether its
/// `ADDR_CONTEXT` is in `PRUNED_CONTEXTS`.
pub(crate) const IS_PRUNED: usize = PRUNED_CONTEXTS_FREQUENCIES + 1;

/// Filter for the `MemAfter` CTL.
pub(crate) const MEM_AFTER_FILTER: usize = IS_PRUNED + 1;

// We use a range check to enforce the ordering.
pub(crate) const RANGE_CHECK: usize = INITIALIZE_AUX + 1;
pub(crate) const RANGE_CHECK: usize = MEM_AFTER_FILTER + 1;
/// The counter column (used for the range check) starts from 0 and increments.
pub(crate) const COUNTER: usize = RANGE_CHECK + 1;
/// The frequencies column used in logUp.
Expand Down
Loading