Skip to content

Commit

Permalink
Implement full memory load and stores
Browse files Browse the repository at this point in the history
This PR adds the remaining features necessary to run the SHA256 benchmark:

    Misaligned loads and stores
    Loads and stores of size < 64 bit

These instructions are generated by the rustc compiler for the wasm32 target and are quite common.

The current implementation is not verifiable as it uses external (free) inputs. It turned out to be quite challenging to implement these loads and stores in a verifiable fashion because they involve many bitwise operations, which are quite expensive.

For now, there is value in having a non-verifiable implementation that makes sure we can run tests and benchmarks related to memory. I've linked a few TODOs to do a verifiable implementation, but first, we would need a proper design for that part. Most likely, we will need to modify the zkAsm processor to support these operations efficiently.

Implementation-wise, there are three steps:

    Conversion from address + offset to slot + offset
    Read/write the value at the correct offset
    Narrowing down the value to the desired type width
  • Loading branch information
aborg-dev committed Jan 30, 2024
1 parent 0ae4aa8 commit 2484725
Show file tree
Hide file tree
Showing 11 changed files with 462 additions and 101 deletions.
23 changes: 23 additions & 0 deletions cranelift/codegen/src/isa/zkasm/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,19 @@ impl LoadOP {
Self::Fld => 0b011,
}
}
/// Returns the width, in bytes, of the value read by this load operation.
///
/// Signed and unsigned variants of the same size share a width.
///
/// # Panics
///
/// Hits `unimplemented!()` for the floating-point loads (`Flw`, `Fld`).
pub(crate) fn width(self) -> u32 {
    match self {
        Self::I8 | Self::U8 => 1,
        Self::I16 | Self::U16 => 2,
        Self::I32 | Self::U32 => 4,
        Self::U64 => 8,
        Self::Flw | Self::Fld => unimplemented!(),
    }
}
}

impl StoreOP {
Expand Down Expand Up @@ -843,6 +856,16 @@ impl StoreOP {
Self::Fsd => 0b011,
}
}
/// Returns the width, in bytes, of the value written by this store operation.
///
/// # Panics
///
/// Hits `unimplemented!()` for the floating-point stores (`Fsw`, `Fsd`).
pub(crate) fn width(self) -> u32 {
    match self {
        Self::Fsw | Self::Fsd => unimplemented!(),
        Self::I64 => 8,
        Self::I32 => 4,
        Self::I16 => 2,
        Self::I8 => 1,
    }
}
}

impl IntSelectOP {
Expand Down
129 changes: 126 additions & 3 deletions cranelift/codegen/src/isa/zkasm/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,14 +591,74 @@ impl MachInstEmit for Inst {
match from {
AMode::RegOffset(r, ..) => {
debug_assert_eq!(r, e0());
// TODO(#43): Implement the conversion using verifiable computations.
put_string(
&format!(
"${{ ({}) / 8 }} => {}\n",
access_reg_with_offset(r, offset),
reg_name(r)
),
sink,
);
let rem = offset % 8;
let width = op.width() as i64;

put_string(
&format!(
"$ => {} :MLOAD(MEM:{})\n",
reg_name(rd.to_reg()),
access_reg_with_offset(r, offset)
reg_name(r)
),
sink,
);
// TODO(#34): Implement unaligned and narrow reads using verifiable
// computations.
if rem > 0 {
put_string(
&format!(
"${{ {} >> {} }} => {}\n",
reg_name(rd.to_reg()),
8 * rem,
reg_name(rd.to_reg())
),
sink,
);
}

// Handle the case when read spans two slots.
if rem + width > 8 {
let rem = rem + width - 8;
put_string(
&format!(
"$ => {} :MLOAD(MEM:{})\n",
reg_name(d0()),
access_reg_with_offset(r, 1)
),
sink,
);
put_string(
&format!(
"${{ (D << {}) | {} }} => {}\n",
64 - 8 * rem,
reg_name(rd.to_reg()),
reg_name(rd.to_reg()),
),
sink,
);
}

// Mask the value to the width of the resulting type.
if width < 8 {
put_string(
&format!(
"${{ {} & ((1 << {}) - 1) }} => {}\n",
reg_name(rd.to_reg()),
8 * width,
reg_name(rd.to_reg()),
),
sink,
);
}
}
AMode::SPOffset(..) | AMode::NominalSPOffset(..) | AMode::FPOffset(..) => {
assert_eq!(offset % 8, 0);
Expand Down Expand Up @@ -629,14 +689,77 @@ impl MachInstEmit for Inst {
match to {
AMode::RegOffset(r, ..) => {
debug_assert_eq!(r, e0());
// TODO(#43): Implement the conversion using verifiable computations.
put_string(
&format!(
"{} :MSTORE(MEM:{})\n",
"${{ ({}) / 8 }} => {}\n",
access_reg_with_offset(r, offset),
reg_name(r)
),
sink,
);
let rem = offset % 8;
let width = op.width() as i64;

if op == StoreOP::I64 && rem == 0 {
put_string(
&format!("{} :MSTORE(MEM:{})\n", reg_name(src), reg_name(r)),
sink,
);
return;
}

// TODO(#34): Implement unaligned and narrow writes using verifiable
// computations.
if width < 8 {
put_string(
&format!(
"${{ {} & ((1 << {}) - 1) }} => {}\n",
reg_name(src),
8 * width,
reg_name(src),
),
sink,
);
}
put_string(
&format!("$ => {} :MLOAD(MEM:{})\n", reg_name(d0()), reg_name(r)),
sink,
);
put_string(
&format!(
"${{ (D & ~(((1 << {}) - 1) << {})) | ({} << {}) }} :MSTORE(MEM:{})\n",
8 * width,
8 * rem,
reg_name(src),
access_reg_with_offset(r, offset)
8 * rem,
reg_name(r),
),
sink,
);

// Handle the case when write spans two slots.
if rem + width > 8 {
let rem = rem + width - 8;
put_string(
&format!(
"$ => {} :MLOAD(MEM:{})\n",
reg_name(d0()),
access_reg_with_offset(r, 1)
),
sink,
);
put_string(
&format!(
"${{ (D & ~((1 << {}) - 1)) | ({} & ((1 << {}) - 1)) }} :MSTORE(MEM:{})\n",
8 * rem,
reg_name(src),
8 * rem,
access_reg_with_offset(r, 1),
),
sink,
);
}
}
AMode::SPOffset(..) | AMode::NominalSPOffset(..) | AMode::FPOffset(..) => {
assert_eq!(offset % 8, 0);
Expand Down
8 changes: 7 additions & 1 deletion cranelift/codegen/src/isa/zkasm/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,19 @@ fn zkasm_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandC
if let Some(r) = from.get_allocatable_register() {
collector.reg_fixed_use(r, e0());
}
let mut clobbered = PRegSet::empty();
clobbered.add(d0().to_real_reg().unwrap().into());
collector.reg_clobbers(clobbered);
collector.reg_def(rd);
}
&Inst::Store { to, src, .. } => {
if let Some(r) = to.get_allocatable_register() {
collector.reg_fixed_use(r, e0());
}
collector.reg_use(src);
let mut clobbered = PRegSet::empty();
clobbered.add(d0().to_real_reg().unwrap().into());
collector.reg_clobbers(clobbered);
collector.reg_late_use(src);
}
&Inst::Args { ref args } => {
for arg in args {
Expand Down
2 changes: 1 addition & 1 deletion cranelift/filetests/src/test_zkasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ mod tests {

// Generate const data segments definitions.
for (offset, data) in data_segments {
program.push(format!(" {offset} => E"));
program.push(format!(" {} => E", offset / 8));
// Each slot stores 8 consecutive u8 numbers, with earlier addresses stored in lower
// bits.
for (i, chunk) in data.chunks(8).enumerate() {
Expand Down
Loading

0 comments on commit 2484725

Please sign in to comment.