Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move zkasm generation into zkasm_codegen.rs #216

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cranelift/filetests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod match_directive;
mod runner;
mod runone;
mod subtest;
mod zkasm_codegen;
mod zkasm_runner;

mod test_alias_analysis;
Expand Down
234 changes: 3 additions & 231 deletions cranelift/filetests/src/test_zkasm.rs
Original file line number Diff line number Diff line change
@@ -1,248 +1,20 @@
#[cfg(test)]
mod tests {
use crate::zkasm_codegen;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use regex::Regex;
use std::fs::read_to_string;
use std::io::Error;

use cranelift_codegen::entity::EntityRef;
use cranelift_codegen::ir::function::FunctionParameters;
use cranelift_codegen::ir::ExternalName;
use cranelift_codegen::isa::zkasm;
use cranelift_codegen::{settings, FinalizedMachReloc, FinalizedRelocTarget};
use cranelift_wasm::{translate_module, ZkasmEnvironment};

use walkdir::WalkDir;
use wasmtime::*;

/// Initializes `env_logger` in test mode for the tests in this module.
///
/// The result of `try_init` is deliberately discarded, so a failed
/// initialization does not abort the calling test.
fn setup() {
    let mut builder = env_logger::builder();
    builder.is_test(true);
    let _ = builder.try_init();
}

/// Builds the lines that open a generated ZKASM program.
///
/// The output contains, in order:
/// 1. one `VAR GLOBAL global_<index>` declaration per entry of `globals`;
/// 2. the `start:` label followed by one `MSTORE` per global initializer;
/// 3. for every `(offset, bytes)` pair of `data_segments`, a line loading
///    `offset / 8` into `E` followed by one `MSTORE` per 8-byte chunk;
/// 4. stack pointer setup, a jump to `function_<start_func_index>` and a
///    jump to `finalizeExecution`.
///
/// # Panics
///
/// Panics (via `unimplemented!`) when a global initializer is neither
/// `I32Const` nor `I64Const`.
fn generate_preamble(
    start_func_index: usize,
    globals: &[(cranelift_wasm::GlobalIndex, cranelift_wasm::GlobalInit)],
    data_segments: &[(u64, Vec<u8>)],
) -> Vec<String> {
    // Generate global variable definitions.
    let mut program: Vec<String> = globals
        .iter()
        .map(|(key, _)| format!("VAR GLOBAL global_{}", key.index()))
        .collect();

    program.push("start:".to_string());
    for (key, init) in globals {
        // ZKASM stores constants in 2-complement form, so the signed initializer is cast to the
        // unsigned type of the same width before being printed.
        let store = match init {
            cranelift_wasm::GlobalInit::I32Const(v) => format!(
                " {} :MSTORE(global_{}) ;; Global32({})",
                *v as u32,
                key.index(),
                v
            ),
            cranelift_wasm::GlobalInit::I64Const(v) => format!(
                " {} :MSTORE(global_{}) ;; Global64({})",
                *v as u64,
                key.index(),
                v
            ),
            _ => unimplemented!("Global type is not supported"),
        };
        program.push(store);
    }

    // Generate const data segments definitions.
    for (offset, data) in data_segments {
        program.push(format!(" {} => E", offset / 8));
        // Each slot stores 8 consecutive u8 numbers, with earlier addresses stored in lower
        // bits: folding the chunk back-to-front leaves its first byte in the lowest 8 bits.
        for (i, chunk) in data.chunks(8).enumerate() {
            let chunk_data = chunk
                .iter()
                .rev()
                .fold(0u64, |acc, &byte| (acc << 8) | byte as u64);
            program.push(format!(" {chunk_data}n :MSTORE(MEM:E + {i})"));
        }
    }

    // The total amount of stack available on ZKASM processor is 2^16 of 8-byte words.
    // Stack memory is a separate region that is independent from the heap.
    program.extend([
        " 0xffff => SP".to_string(),
        " zkPC + 2 => RR".to_string(),
        format!(" :JMP(function_{})", start_func_index),
        " :JMP(finalizeExecution)".to_string(),
    ]);
    program
}

/// Returns the fixed lines that close every generated ZKASM program: the
/// `finalizeExecution` loop and the `INCLUDE` of the `2-exp` helper file.
fn generate_postamble() -> Vec<String> {
    // In the prover, the program always runs for a fixed number of steps (e.g. 2^23), so an
    // infinite loop is needed at the end of the program to fill the execution trace to the
    // expected number of steps.
    // In the future we might need to put zero in all registers here.
    const POSTAMBLE: [&str; 4] = [
        "finalizeExecution:",
        " ${beforeLast()} :JMPN(finalizeExecution)",
        " :JMP(start)",
        "INCLUDE \"helpers/2-exp.zkasm\"",
    ];
    POSTAMBLE.iter().map(|line| line.to_string()).collect()
}

// TODO: Relocations should be generated by a linker and/or clift itself.
/// Rewrites `code_buffer` in place, replacing the text line found at each
/// relocation offset with the code for that relocation.
///
/// For every entry of `relocs` the bytes from the relocation offset up to (but
/// not including) the next `\n` are replaced:
/// * a user-named external function with index `0` becomes ` B :ASSERT`
///   (NOTE(review): presumably index 0 is the imported assert function — confirm);
/// * any other user-named external function becomes a call sequence jumping to
///   `function_<index>`;
/// * every other relocation target becomes ` UNKNOWN`.
///
/// Offsets in `relocs` refer to the buffer as originally emitted, so a running
/// `delta` tracks how much earlier replacements grew or shrank the buffer.
/// NOTE(review): this assumes `relocs` is ordered by increasing offset — confirm.
///
/// A placeholder on the final line without a trailing newline is handled by
/// treating the end of the buffer as the end of the line; the previous
/// byte-by-byte scan indexed past the end of the buffer and panicked.
///
/// # Panics
///
/// Panics if an adjusted relocation offset lies beyond the end of the buffer.
fn fix_relocs(
    code_buffer: &mut Vec<u8>,
    params: &FunctionParameters,
    relocs: &[FinalizedMachReloc],
) {
    let mut delta = 0i32;
    for reloc in relocs {
        let start = (reloc.offset as i32 + delta) as usize;
        // Length of the placeholder: everything up to the next newline, or up to the end of the
        // buffer when the placeholder sits on an unterminated last line.
        let placeholder_len = code_buffer.get(start..).map_or(0, |rest| {
            rest.iter()
                .position(|&byte| byte == b'\n')
                .unwrap_or(rest.len())
        });
        let end = start + placeholder_len;

        let code = if let FinalizedRelocTarget::ExternalName(ExternalName::User(name)) =
            reloc.target
        {
            let name = &params.user_named_funcs()[name];
            if name.index == 0 {
                b" B :ASSERT".to_vec()
            } else {
                format!(" zkPC + 2 => RR\n :JMP(function_{})", name.index)
                    .as_bytes()
                    .to_vec()
            }
        } else {
            b" UNKNOWN".to_vec()
        };
        // Later offsets shift by the difference between inserted and removed bytes.
        delta += code.len() as i32 - placeholder_len as i32;

        code_buffer.splice(start..end, code);
    }
}

// TODO: Labels optimization already happens in `MachBuffer`, we need to find a way to leverage
// it.
/// Makes the labels of one function unique across the program and drops labels
/// (and jumps) that are not needed.
///
/// Label name is formatted as follows: <label_name>_<function_id>_<label_id>
/// Function id is unique through whole program while label id is unique only
/// inside given function.
/// Label name must begin from label_.
///
/// Two kinds of lines are rewritten by inserting `_<func_index>` before the
/// last `_` of the line:
/// * label markers, i.e. lines starting with `label_` (e.g. `label_3:`);
/// * jumps to a label, i.e. other lines containing `label_`
///   (e.g. `A :JMPNZ(label_3)`).
///
/// Afterwards these lines are removed from the output:
/// * a label marker that no line jumps to;
/// * a label marker together with its only jump, when that jump is the line
///   directly preceding the marker.
///
/// # Panics
///
/// Panics if a jump line contains `label_` but no `)`, or if the `)` found
/// precedes the label name.
fn optimize_labels(code: &[&str], func_index: usize) -> Vec<String> {
    let func_suffix = format!("_{}", func_index);
    let mut label_definition: HashMap<String, usize> = HashMap::new();
    let mut label_uses: HashMap<String, Vec<usize>> = HashMap::new();
    let mut lines: Vec<String> = Vec::with_capacity(code.len());
    for (index, line) in code.iter().enumerate() {
        let mut line = line.to_string();
        if line.starts_with("label_") {
            // Handles lines with a label marker, e.g.:
            //   <label_name>_XXX:
            let separator = line.rfind('_').expect("Failed to parse label index");
            // The name is the whole line without its trailing `:`.
            let label_name = line[..line.len() - 1].to_string();
            line.insert_str(separator, &func_suffix);
            label_definition.insert(label_name, index);
        } else if let Some(name_start) = line.find("label_") {
            // Handles lines with a jump to label, e.g.:
            //   A : JMPNZ(<label_name>_XXX)
            let separator = line
                .rfind('_')
                .expect("a line containing `label_` contains `_`");
            // The panic message is only built on the failure path.
            let name_end = line
                .rfind(')')
                .unwrap_or_else(|| panic!("Error parsing label line '{}'", line));
            let label_name = line[name_start..name_end].to_string();
            line.insert_str(separator, &func_suffix);
            label_uses.entry(label_name).or_default().push(index);
        }
        lines.push(line);
    }

    // Mark the lines to drop first and filter once at the end, instead of shifting the tail of
    // the vector on every single removal.
    let mut keep = vec![true; lines.len()];
    for (label, &label_line) in &label_definition {
        match label_uses.get(label).map(Vec::as_slice) {
            // Nothing jumps to this label: the marker is dead.
            None => keep[label_line] = false,
            // The only jump targets the very next line: both jump and marker are redundant.
            Some(&[use_line]) if use_line + 1 == label_line => {
                keep[use_line] = false;
                keep[label_line] = false;
            }
            Some(_) => {}
        }
    }

    lines
        .into_iter()
        .zip(keep)
        .filter_map(|(line, kept)| if kept { Some(line) } else { None })
        .collect()
}

/// Translates a WebAssembly module (binary encoding) into the text of a ZKASM
/// program.
///
/// The program consists of the preamble, one `function_<index>:` section per
/// function body of the module (with relocations resolved by `fix_relocs` and
/// labels processed by `optimize_labels`), and the postamble, joined with
/// newlines. Function indices are offset by the number of imported functions.
///
/// # Panics
///
/// Panics if the module cannot be translated, has no start function, a
/// function fails to compile, or the emitted code is not valid UTF-8.
fn generate_zkasm(wasm_module: &[u8]) -> String {
    let flags = settings::Flags::new(settings::builder());
    let isa = zkasm::isa_builder("zkasm-unknown-unknown".parse().unwrap())
        .finish(flags)
        .unwrap();
    let mut environ = ZkasmEnvironment::new(isa.frontend_config());
    translate_module(wasm_module, &mut environ).unwrap();

    let entry = environ
        .info
        .start_func
        .expect("Must have a start function");
    // TODO: Preamble should be generated by a linker and/or clift itself.
    let mut program: Vec<String> = generate_preamble(
        entry.index(),
        &environ.info.global_inits,
        &environ.info.data_inits,
    );

    let import_count = environ.get_num_func_imports();
    // A single compilation context is reused for every function and cleared after each one.
    let mut context = cranelift_codegen::Context::new();
    for (def_index, func) in environ.info.function_bodies.iter() {
        let func_index = import_count + def_index.index();
        program.push(format!("function_{}:", func_index));

        let mut mem = vec![];
        context.func = func.clone();
        let compiled = context
            .compile_and_emit(&*isa, &mut mem, &mut Default::default())
            .unwrap();
        let mut code_buffer = compiled.code_buffer().to_vec();
        fix_relocs(&mut code_buffer, &func.params, compiled.buffer.relocs());

        // The emitted code is treated as text, one instruction or label per line.
        let text = std::str::from_utf8(&code_buffer).unwrap();
        let text_lines: Vec<&str> = text.lines().collect();
        program.extend(optimize_labels(&text_lines, func_index));

        context.clear();
    }

    program.extend(generate_postamble());
    program.join("\n")
}

fn run_wat_file(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
let engine = Engine::default();
let binary = wat::parse_file(path)?;
Expand Down Expand Up @@ -287,7 +59,7 @@ mod tests {

fn test_module(name: &str) {
let module_binary = wat::parse_file(format!("../zkasm_data/{name}.wat")).unwrap();
let program = generate_zkasm(&module_binary);
let program = zkasm_codegen::generate_zkasm(&module_binary);
let expected =
expect_test::expect_file![format!("../../zkasm_data/generated/{name}.zkasm")];
expected.assert_eq(&program);
Expand Down Expand Up @@ -379,7 +151,7 @@ mod tests {
.join(path)
.join(format!("generated/{name}.zkasm"))];
let result = std::panic::catch_unwind(|| {
let program = generate_zkasm(&module_binary);
let program = zkasm_codegen::generate_zkasm(&module_binary);
expected.assert_eq(&program);
});
if let Err(err) = result {
Expand Down
Loading
Loading