diff --git a/cranelift/filetests/src/lib.rs b/cranelift/filetests/src/lib.rs index 726521a5d518..173736867647 100644 --- a/cranelift/filetests/src/lib.rs +++ b/cranelift/filetests/src/lib.rs @@ -22,6 +22,7 @@ mod match_directive; mod runner; mod runone; mod subtest; +mod zkasm_codegen; mod zkasm_runner; mod test_alias_analysis; diff --git a/cranelift/filetests/src/test_zkasm.rs b/cranelift/filetests/src/test_zkasm.rs index 0029b28a52e6..78753b546965 100644 --- a/cranelift/filetests/src/test_zkasm.rs +++ b/cranelift/filetests/src/test_zkasm.rs @@ -1,5 +1,6 @@ #[cfg(test)] mod tests { + use crate::zkasm_codegen; use std::collections::HashMap; use std::path::{Path, PathBuf}; @@ -7,13 +8,6 @@ mod tests { use std::fs::read_to_string; use std::io::Error; - use cranelift_codegen::entity::EntityRef; - use cranelift_codegen::ir::function::FunctionParameters; - use cranelift_codegen::ir::ExternalName; - use cranelift_codegen::isa::zkasm; - use cranelift_codegen::{settings, FinalizedMachReloc, FinalizedRelocTarget}; - use cranelift_wasm::{translate_module, ZkasmEnvironment}; - use walkdir::WalkDir; use wasmtime::*; @@ -21,228 +15,6 @@ mod tests { let _ = env_logger::builder().is_test(true).try_init(); } - fn generate_preamble( - start_func_index: usize, - globals: &[(cranelift_wasm::GlobalIndex, cranelift_wasm::GlobalInit)], - data_segments: &[(u64, Vec)], - ) -> Vec { - let mut program: Vec = Vec::new(); - - // Generate global variable definitions. - for (key, _) in globals { - program.push(format!("VAR GLOBAL global_{}", key.index())); - } - - program.push("start:".to_string()); - for (key, init) in globals { - match init { - cranelift_wasm::GlobalInit::I32Const(v) => { - // ZKASM stores constants in 2-complement form, so we need a cast to unsigned. - program.push(format!( - " {} :MSTORE(global_{}) ;; Global32({})", - *v as u32, - key.index(), - v - )); - } - cranelift_wasm::GlobalInit::I64Const(v) => { - // ZKASM stores constants in 2-complement form, so we need a cast to unsigned. - program.push(format!( - " {} :MSTORE(global_{}) ;; Global64({})", - *v as u64, - key.index(), - v - )); - } - _ => unimplemented!("Global type is not supported"), - } - } - - // Generate const data segments definitions. - for (offset, data) in data_segments { - program.push(format!(" {} => E", offset / 8)); - // Each slot stores 8 consecutive u8 numbers, with earlier addresses stored in lower - // bits. - for (i, chunk) in data.chunks(8).enumerate() { - let mut chunk_data = 0u64; - for c in chunk.iter().rev() { - chunk_data <<= 8; - chunk_data |= *c as u64; - } - program.push(format!(" {chunk_data}n :MSTORE(MEM:E + {i})")); - } - } - - // The total amount of stack available on ZKASM processor is 2^16 of 8-byte words. - // Stack memory is a separate region that is independent from the heap. - program.push(" 0xffff => SP".to_string()); - program.push(" zkPC + 2 => RR".to_string()); - program.push(format!(" :JMP(function_{})", start_func_index)); - program.push(" :JMP(finalizeExecution)".to_string()); - program - } - - fn generate_postamble() -> Vec { - let mut program: Vec = Vec::new(); - // In the prover, the program always runs for a fixed number of steps (e.g. 2^23), so we - // need an infinite loop at the end of the program to fill the execution trace to the - // expected number of steps. - // In the future we might need to put zero in all registers here. - program.push("finalizeExecution:".to_string()); - program.push(" ${beforeLast()} :JMPN(finalizeExecution)".to_string()); - program.push(" :JMP(start)".to_string()); - program.push("INCLUDE \"helpers/2-exp.zkasm\"".to_string()); - program - } - - // TODO: Relocations should be generated by a linker and/or clift itself. - fn fix_relocs( - code_buffer: &mut Vec, - params: &FunctionParameters, - relocs: &[FinalizedMachReloc], - ) { - let mut delta = 0i32; - for reloc in relocs { - let start = (reloc.offset as i32 + delta) as usize; - let mut pos = start; - while code_buffer[pos] != b'\n' { - pos += 1; - delta -= 1; - } - - let code = if let FinalizedRelocTarget::ExternalName(ExternalName::User(name)) = - reloc.target - { - let name = ¶ms.user_named_funcs()[name]; - if name.index == 0 { - b" B :ASSERT".to_vec() - } else { - format!(" zkPC + 2 => RR\n :JMP(function_{})", name.index) - .as_bytes() - .to_vec() - } - } else { - b" UNKNOWN".to_vec() - }; - delta += code.len() as i32; - - code_buffer.splice(start..pos, code); - } - } - - // TODO: Labels optimization already happens in `MachBuffer`, we need to find a way to leverage - // it. - /// Label name is formatted as follows: __ - /// Function id is unique through whole program while label id is unique only - /// inside given function. - /// Label name must begin from label_. - fn optimize_labels(code: &[&str], func_index: usize) -> Vec { - let mut label_definition: HashMap = HashMap::new(); - let mut label_uses: HashMap> = HashMap::new(); - let mut lines = Vec::new(); - for (index, line) in code.iter().enumerate() { - let mut line = line.to_string(); - if line.starts_with(&"label_") { - // Handles lines with a label marker, e.g.: - // _XXX: - let index_begin = line.rfind("_").expect("Failed to parse label index") + 1; - let label_name: String = line[..line.len() - 1].to_string(); - line.insert_str(index_begin - 1, &format!("_{}", func_index)); - label_definition.insert(label_name, index); - } else if line.contains(&"label_") { - // Handles lines with a jump to label, e.g.: - // A : JMPNZ(_XXX) - let pos = line.rfind(&"_").unwrap() + 1; - let label_name = line[line - .find("label_") - .expect(&format!("Error parsing label line '{}'", line)) - ..line - .rfind(")") - .expect(&format!("Error parsing label line '{}'", line))] - .to_string(); - line.insert_str(pos - 1, &format!("_{}", func_index)); - label_uses.entry(label_name).or_default().push(index); - } - lines.push(line); - } - - let mut lines_to_delete = Vec::new(); - for (label, label_line) in label_definition { - match label_uses.entry(label) { - std::collections::hash_map::Entry::Occupied(uses) => { - if uses.get().len() == 1 { - let use_line = uses.get()[0]; - if use_line + 1 == label_line { - lines_to_delete.push(use_line); - lines_to_delete.push(label_line); - } - } - } - std::collections::hash_map::Entry::Vacant(_) => { - lines_to_delete.push(label_line); - } - } - } - lines_to_delete.sort(); - lines_to_delete.reverse(); - for index in lines_to_delete { - lines.remove(index); - } - lines - } - - fn generate_zkasm(wasm_module: &[u8]) -> String { - let flag_builder = settings::builder(); - let isa_builder = zkasm::isa_builder("zkasm-unknown-unknown".parse().unwrap()); - let isa = isa_builder - .finish(settings::Flags::new(flag_builder)) - .unwrap(); - let mut zkasm_environ = ZkasmEnvironment::new(isa.frontend_config()); - translate_module(wasm_module, &mut zkasm_environ).unwrap(); - - let mut program: Vec = Vec::new(); - - let start_func = zkasm_environ - .info - .start_func - .expect("Must have a start function"); - // TODO: Preamble should be generated by a linker and/or clift itself. - program.append(&mut generate_preamble( - start_func.index(), - &zkasm_environ.info.global_inits, - &zkasm_environ.info.data_inits, - )); - - let num_func_imports = zkasm_environ.get_num_func_imports(); - let mut context = cranelift_codegen::Context::new(); - for (def_index, func) in zkasm_environ.info.function_bodies.iter() { - let func_index = num_func_imports + def_index.index(); - program.push(format!("function_{}:", func_index)); - - let mut mem = vec![]; - context.func = func.clone(); - let compiled_code = context - .compile_and_emit(&*isa, &mut mem, &mut Default::default()) - .unwrap(); - let mut code_buffer = compiled_code.code_buffer().to_vec(); - fix_relocs( - &mut code_buffer, - &func.params, - compiled_code.buffer.relocs(), - ); - - let code = std::str::from_utf8(&code_buffer).unwrap(); - let lines: Vec<&str> = code.lines().collect(); - let mut lines = optimize_labels(&lines, func_index); - program.append(&mut lines); - - context.clear(); - } - - program.append(&mut generate_postamble()); - program.join("\n") - } - fn run_wat_file(path: &Path) -> Result<(), Box> { let engine = Engine::default(); let binary = wat::parse_file(path)?; @@ -287,7 +59,7 @@ mod tests { fn test_module(name: &str) { let module_binary = wat::parse_file(format!("../zkasm_data/{name}.wat")).unwrap(); - let program = generate_zkasm(&module_binary); + let program = zkasm_codegen::generate_zkasm(&module_binary); let expected = expect_test::expect_file![format!("../../zkasm_data/generated/{name}.zkasm")]; expected.assert_eq(&program); @@ -379,7 +151,7 @@ mod tests { .join(path) .join(format!("generated/{name}.zkasm"))]; let result = std::panic::catch_unwind(|| { - let program = generate_zkasm(&module_binary); + let program = zkasm_codegen::generate_zkasm(&module_binary); expected.assert_eq(&program); }); if let Err(err) = result { diff --git a/cranelift/filetests/src/zkasm_codegen.rs b/cranelift/filetests/src/zkasm_codegen.rs new file mode 100644 index 000000000000..8514b5db905f --- /dev/null +++ b/cranelift/filetests/src/zkasm_codegen.rs @@ -0,0 +1,233 @@ +use cranelift_codegen::entity::EntityRef; +use cranelift_codegen::ir::function::FunctionParameters; +use cranelift_codegen::ir::ExternalName; +use cranelift_codegen::isa::zkasm; +use cranelift_codegen::{settings, FinalizedMachReloc, FinalizedRelocTarget}; +use cranelift_wasm::{translate_module, ZkasmEnvironment}; +use std::collections::HashMap; + +#[allow(dead_code)] +pub fn generate_zkasm(wasm_module: &[u8]) -> String { + let flag_builder = settings::builder(); + let isa_builder = zkasm::isa_builder("zkasm-unknown-unknown".parse().unwrap()); + let isa = isa_builder + .finish(settings::Flags::new(flag_builder)) + .unwrap(); + let mut zkasm_environ = ZkasmEnvironment::new(isa.frontend_config()); + translate_module(wasm_module, &mut zkasm_environ).unwrap(); + + let mut program: Vec = Vec::new(); + + let start_func = zkasm_environ + .info + .start_func + .expect("Must have a start function"); + // TODO: Preamble should be generated by a linker and/or clift itself. + program.append(&mut generate_preamble( + start_func.index(), + &zkasm_environ.info.global_inits, + &zkasm_environ.info.data_inits, + )); + + let num_func_imports = zkasm_environ.get_num_func_imports(); + let mut context = cranelift_codegen::Context::new(); + for (def_index, func) in zkasm_environ.info.function_bodies.iter() { + let func_index = num_func_imports + def_index.index(); + program.push(format!("function_{}:", func_index)); + + let mut mem = vec![]; + context.func = func.clone(); + let compiled_code = context + .compile_and_emit(&*isa, &mut mem, &mut Default::default()) + .unwrap(); + let mut code_buffer = compiled_code.code_buffer().to_vec(); + fix_relocs( + &mut code_buffer, + &func.params, + compiled_code.buffer.relocs(), + ); + + let code = std::str::from_utf8(&code_buffer).unwrap(); + let lines: Vec<&str> = code.lines().collect(); + let mut lines = optimize_labels(&lines, func_index); + program.append(&mut lines); + + context.clear(); + } + + program.append(&mut generate_postamble()); + program.join("\n") +} + +#[allow(dead_code)] +pub fn generate_preamble( + start_func_index: usize, + globals: &[(cranelift_wasm::GlobalIndex, cranelift_wasm::GlobalInit)], + data_segments: &[(u64, Vec)], +) -> Vec { + let mut program: Vec = Vec::new(); + + // Generate global variable definitions. + for (key, _) in globals { + program.push(format!("VAR GLOBAL global_{}", key.index())); + } + + program.push("start:".to_string()); + for (key, init) in globals { + match init { + cranelift_wasm::GlobalInit::I32Const(v) => { + // ZKASM stores constants in 2-complement form, so we need a cast to unsigned. + program.push(format!( + " {} :MSTORE(global_{}) ;; Global32({})", + *v as u32, + key.index(), + v + )); + } + cranelift_wasm::GlobalInit::I64Const(v) => { + // ZKASM stores constants in 2-complement form, so we need a cast to unsigned. + program.push(format!( + " {} :MSTORE(global_{}) ;; Global64({})", + *v as u64, + key.index(), + v + )); + } + _ => unimplemented!("Global type is not supported"), + } + } + + // Generate const data segments definitions. + for (offset, data) in data_segments { + program.push(format!(" {} => E", offset / 8)); + // Each slot stores 8 consecutive u8 numbers, with earlier addresses stored in lower + // bits. + for (i, chunk) in data.chunks(8).enumerate() { + let mut chunk_data = 0u64; + for c in chunk.iter().rev() { + chunk_data <<= 8; + chunk_data |= *c as u64; + } + program.push(format!(" {chunk_data}n :MSTORE(MEM:E + {i})")); + } + } + + // The total amount of stack available on ZKASM processor is 2^16 of 8-byte words. + // Stack memory is a separate region that is independent from the heap. + program.push(" 0xffff => SP".to_string()); + program.push(" zkPC + 2 => RR".to_string()); + program.push(format!(" :JMP(function_{})", start_func_index)); + program.push(" :JMP(finalizeExecution)".to_string()); + program +} + +#[allow(dead_code)] +fn generate_postamble() -> Vec { + let mut program: Vec = Vec::new(); + // In the prover, the program always runs for a fixed number of steps (e.g. 2^23), so we + // need an infinite loop at the end of the program to fill the execution trace to the + // expected number of steps. + // In the future we might need to put zero in all registers here. + program.push("finalizeExecution:".to_string()); + program.push(" ${beforeLast()} :JMPN(finalizeExecution)".to_string()); + program.push(" :JMP(start)".to_string()); + program.push("INCLUDE \"helpers/2-exp.zkasm\"".to_string()); + program +} + +// TODO: Relocations should be generated by a linker and/or clift itself. +#[allow(dead_code)] +fn fix_relocs( + code_buffer: &mut Vec, + params: &FunctionParameters, + relocs: &[FinalizedMachReloc], +) { + let mut delta = 0i32; + for reloc in relocs { + let start = (reloc.offset as i32 + delta) as usize; + let mut pos = start; + while code_buffer[pos] != b'\n' { + pos += 1; + delta -= 1; + } + + let code = + if let FinalizedRelocTarget::ExternalName(ExternalName::User(name)) = reloc.target { + let name = ¶ms.user_named_funcs()[name]; + if name.index == 0 { + b" B :ASSERT".to_vec() + } else { + format!(" zkPC + 2 => RR\n :JMP(function_{})", name.index) + .as_bytes() + .to_vec() + } + } else { + b" UNKNOWN".to_vec() + }; + delta += code.len() as i32; + + code_buffer.splice(start..pos, code); + } +} + +// TODO: Labels optimization already happens in `MachBuffer`, we need to find a way to leverage +// it. +/// Label name is formatted as follows: __ +/// Function id is unique through whole program while label id is unique only +/// inside given function. +/// Label name must begin from label_. +#[allow(dead_code)] +fn optimize_labels(code: &[&str], func_index: usize) -> Vec { + let mut label_definition: HashMap = HashMap::new(); + let mut label_uses: HashMap> = HashMap::new(); + let mut lines = Vec::new(); + for (index, line) in code.iter().enumerate() { + let mut line = line.to_string(); + if line.starts_with(&"label_") { + // Handles lines with a label marker, e.g.: + // _XXX: + let index_begin = line.rfind("_").expect("Failed to parse label index") + 1; + let label_name: String = line[..line.len() - 1].to_string(); + line.insert_str(index_begin - 1, &format!("_{}", func_index)); + label_definition.insert(label_name, index); + } else if line.contains(&"label_") { + // Handles lines with a jump to label, e.g.: + // A : JMPNZ(_XXX) + let pos = line.rfind(&"_").unwrap() + 1; + let label_name = line[line + .find("label_") + .expect(&format!("Error parsing label line '{}'", line)) + ..line + .rfind(")") + .expect(&format!("Error parsing label line '{}'", line))] + .to_string(); + line.insert_str(pos - 1, &format!("_{}", func_index)); + label_uses.entry(label_name).or_default().push(index); + } + lines.push(line); + } + + let mut lines_to_delete = Vec::new(); + for (label, label_line) in label_definition { + match label_uses.entry(label) { + std::collections::hash_map::Entry::Occupied(uses) => { + if uses.get().len() == 1 { + let use_line = uses.get()[0]; + if use_line + 1 == label_line { + lines_to_delete.push(use_line); + lines_to_delete.push(label_line); + } + } + } + std::collections::hash_map::Entry::Vacant(_) => { + lines_to_delete.push(label_line); + } + } + } + lines_to_delete.sort(); + lines_to_delete.reverse(); + for index in lines_to_delete { + lines.remove(index); + } + lines +}