From 64919611a0e39593b8803befb0fe11dbfcbd948f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Fri, 1 Mar 2024 20:17:51 +0100 Subject: [PATCH 01/28] First WAM instructions --- Cargo.lock | 228 +++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 5 +- src/machine/compile.rs | 2 + src/machine/mod.rs | 3 + 4 files changed, 237 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 225cb267f..72bbd9f62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "anyhow" +version = "1.0.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" + +[[package]] +name = "arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" + [[package]] name = "arrayvec" version = "0.5.2" @@ -480,6 +492,136 @@ dependencies = [ "libc", ] +[[package]] +name = "cranelift" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71ad1c9f3eb965e6f7af65524ca625f2800c7951a213ee7336e808dd42808262" +dependencies = [ + "cranelift-codegen", + "cranelift-frontend", +] + +[[package]] +name = "cranelift-bforest" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ec790dbba3970f5dc1fb615e81167adbe90a81b6d5ce53d1d7f97d1da0c816" +dependencies = [ + "cranelift-entity", +] + +[[package]] +name = "cranelift-codegen" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aff1e6625920ec73a0a222bf8f9e5b7fc6a765ccef47e0a010fe327f07eb430a" +dependencies = [ + "bumpalo", + "cranelift-bforest", + "cranelift-codegen-meta", + "cranelift-codegen-shared", + "cranelift-control", + "cranelift-entity", + "cranelift-isle", + "gimli", + "hashbrown 0.14.3", + "log", + "regalloc2", + "smallvec", + "target-lexicon", +] + +[[package]] +name = "cranelift-codegen-meta" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1b92f1d526daa2b1878f0171d3a216a70e3f05d2fe6786a5469299c6919a0a" +dependencies = [ + "cranelift-codegen-shared", +] + +[[package]] +name = "cranelift-codegen-shared" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3333fb7671c8b6a8d24fd05dced0a4940773d9cdddeeb38b1f99260eba2c50e9" + +[[package]] +name = "cranelift-control" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf07b47de593bc43b7bc25a66c509469fd02c2ad850833c6a2dff45e81980c3b" +dependencies = [ + "arbitrary", +] + +[[package]] +name = "cranelift-entity" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfd988176a86d6985dd75ac0762033b36d134c4af8b806b605e8542c489fe1c2" + +[[package]] +name = "cranelift-frontend" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "214f216d2288f695bf0e0c2775d4e58759cd6ad5a118852fb5cc26d817e3cdcc" +dependencies = [ + "cranelift-codegen", + "log", + "smallvec", + "target-lexicon", +] + +[[package]] +name = "cranelift-isle" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaf4dd3bc4bd4d594180a4db16b0b411a9c3a6757a4d13c5ede1177a0131fc80" + +[[package]] +name = "cranelift-jit" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52bb8f1189dd81ed989b2cbd62371d67d35afb8977239d4a917a05b7e22e0692" +dependencies = [ + "anyhow", + "cranelift-codegen", + "cranelift-control", + "cranelift-entity", + "cranelift-module", + "cranelift-native", + "libc", + "log", + "region", + "target-lexicon", + "wasmtime-jit-icache-coherence", + "windows-sys 0.52.0", +] + +[[package]] +name = "cranelift-module" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ecedd1fec01db9760f94dc2abca90975ef5e90746be5aa0041dbba99268ac23" +dependencies = [ + "anyhow", + "cranelift-codegen", + "cranelift-control", +] + +[[package]] +name = "cranelift-native" +version = "0.105.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ddcb8ffe184e9aaf941387f9bcd016236fe1bedc29a84af0771d98443a71aea" +dependencies = [ + "cranelift-codegen", + "libc", + "target-lexicon", +] + [[package]] name = "criterion" version = "0.5.1" @@ -831,6 +973,12 @@ dependencies = [ "str-buf", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + [[package]] name = "fastrand" version = "2.0.1" @@ -1053,6 +1201,11 @@ name = "gimli" version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +dependencies = [ + "fallible-iterator", + "indexmap 2.1.0", + "stable_deref_trait", +] [[package]] name = "git-version" @@ -1111,6 +1264,15 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -1539,6 +1701,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "mach" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" +dependencies = [ + "libc", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2312,6 +2483,19 @@ version = "0.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d813022b2e00774a48eaf43caaa3c20b45f040ba8cbf398e2e8911a06668dbe6" +[[package]] +name = "regalloc2" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad156d539c879b7a24a363a2016d77961786e71f48f2e2fc8302a92abd2429a6" +dependencies = [ + "hashbrown 0.13.2", + "log", + "rustc-hash", + "slice-group-by", + "smallvec", +] + [[package]] name = "regex" version = "1.10.2" @@ -2347,6 +2531,18 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "region" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877e54ea2adcd70d80e9179344c97f93ef0dffd6b03e1f4529e6e83ab2fa9ae0" +dependencies = [ + "bitflags 1.3.2", + "libc", + "mach", + "winapi", +] + [[package]] name = "reqwest" version = "0.11.23" @@ -2449,6 +2645,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "0.38.28" @@ -2561,6 +2763,9 @@ dependencies = [ "chrono", "console_error_panic_hook", "cpu-time", + "cranelift", + "cranelift-jit", + "cranelift-module", "criterion", "crossterm", "crrl", @@ -2855,6 +3060,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slice-group-by" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826167069c09b99d56f31e9ae5c99049e932a98c9dc2dac47645b08dbbf76ba7" + [[package]] name = "smallvec" version = "1.11.2" @@ -3067,6 +3278,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "target-lexicon" +version = "0.12.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" + [[package]] name = "tempfile" version = "3.9.0" @@ -3573,6 +3790,17 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +[[package]] +name = "wasmtime-jit-icache-coherence" +version = "18.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c87783a70d3b7602834118f42e73ed8979f1b75a01e0fc4bf311cc6dc31f8fc" +dependencies = [ + "cfg-if", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "web-sys" version = "0.3.66" diff --git a/Cargo.toml b/Cargo.toml index 3e06970e0..ccc7b36d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ build = "build/main.rs" rust-version = "1.70" [lib] -crate-type = ["cdylib", "rlib"] +crate-type = ["cdylib", "rlib", "staticlib"] [features] default = ["ffi", "repl", "hostname", "tls", "http", "crypto-full"] @@ -83,6 +83,9 @@ native-tls = { version = "0.2.4", optional = true } warp = { version = "=0.3.5", features = ["tls"], optional = true } reqwest = { version = "0.11.18", optional = true } tokio = { version = "1.28.2", features = ["full"] } +cranelift = "0.105.0" +cranelift-jit = "0.105.0" +cranelift-module = "0.105.0" [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.10", features = ["js"] } diff --git a/src/machine/compile.rs b/src/machine/compile.rs index 7f967e8bd..6eaed8aad 100644 --- a/src/machine/compile.rs +++ b/src/machine/compile.rs @@ -1355,6 +1355,8 @@ impl<'a, LS: LoadState<'a>> Loader<'a, LS> { index_ptr, ); + dbg!(&code); + self.wam_prelude.code.extend(code); Ok(code_index) } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index 5326c83cf..f67315a54 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -12,6 +12,7 @@ pub mod disjuncts; pub mod dispatch; pub mod gc; pub mod heap; +pub mod jit; pub mod lib_machine; pub mod load_state; pub mod machine_errors; @@ -1140,6 +1141,8 @@ impl Machine { Ok(()) } + + #[inline(always)] fn try_execute(&mut self, name: Atom, arity: usize, idx: IndexPtr) -> CallResult { let compiled_tl_index = idx.p() as usize; From 1775efba447b4dbda2cb742a8712260bd17bf219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Fri, 1 Mar 2024 20:18:19 +0100 Subject: [PATCH 02/28] JIT Rust file --- src/machine/jit.rs | 198 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 src/machine/jit.rs diff --git a/src/machine/jit.rs b/src/machine/jit.rs new file mode 100644 index 000000000..d69bf1af3 --- /dev/null +++ b/src/machine/jit.rs @@ -0,0 +1,198 @@ +use std::collections::HashMap; + +use crate::instructions::*; +use crate::machine::*; + +use cranelift::prelude::*; +use cranelift::prelude::codegen::ir::immediates::Offset32; +use cranelift_jit::{JITBuilder, JITModule}; +use cranelift_module::{Linkage, Module}; + +struct CompileOutput { + module: JITModule, + code_ptr: *const u8, +} + +#[derive(Debug, PartialEq)] +enum JitCompileError { + UndefinedPredicate, + InstructionNotImplemented, +} + +struct JitMachine { + modules: HashMap, + machine_state: *const u8, + registers: *const u8, + write_literal_to_var: *const u8, +} + +impl JitMachine { + pub fn new(machine_st: &MachineState) -> Self { + + JitMachine { + modules: HashMap::new(), + machine_state: machine_st as *const MachineState as *const u8, + registers: machine_st.registers.as_ptr() as *const u8, + write_literal_to_var: MachineState::write_literal_to_var as *const u8, + } + } + + // For now, one module = one predicate + // Functions must take N parameters (arity) + // Access to MachineState via global pointer + // MachineState Registers + ShadowRegisters?? + // Use TAIL call convention + pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError>{ + let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true")], cranelift_module::default_libcall_names()).unwrap(); + builder.symbols(self.modules.iter().map(|(k, v)| (k,v.code_ptr))); + + let mut module = JITModule::new(builder); + let mut ctx = module.make_context(); + let mut func_ctx = FunctionBuilderContext::new(); + + let mut sig = module.make_signature(); + // Set arguments/returns + sig.call_conv = isa::CallConv::Tail; + ctx.func.signature = sig.clone(); + + let mut func = module.declare_function(name, Linkage::Local, &sig).unwrap(); + + let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); + let block = fn_builder.create_block(); + fn_builder.switch_to_block(block); + for wam_instr in code { + match wam_instr { + Instruction::Proceed => { + fn_builder.ins().return_(&[]); + fn_builder.seal_all_blocks(); + fn_builder.finalize(); + break; + } + Instruction::ExecuteNamed(arity, pred_name, ..) => { + let mut callee_func_sig = module.make_signature(); + callee_func_sig.call_conv = isa::CallConv::Tail; + // right now, all predicates have 0 arguments. In the future with shadow registers, we could improve this + if let Ok(callee_func) = module.declare_function(&format!("{}/{}", pred_name.as_str(), arity), Linkage::Import, &callee_func_sig) { + let func_ref = module.declare_func_in_func(callee_func, fn_builder.func); + fn_builder.ins().return_call(func_ref, &[]); + fn_builder.seal_all_blocks(); + fn_builder.finalize(); + break; + } else { + return Err(JitCompileError::UndefinedPredicate); + } + } + Instruction::GetConstant(_, c, reg) => { + let reg_ptr = fn_builder.ins().iconst(types::I64, self.registers as i64); // TODO: call deref + let reg_num = reg.reg_num(); + let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((reg_num as i32)*8)); + let c = unsafe { std::mem::transmute::(u64::from(c)) }; + let c_value = fn_builder.ins().iconst(types::I64, c); + let machine_state_value = fn_builder.ins().iconst(types::I64, self.machine_state as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let write_literal_to_var = fn_builder.ins().iconst(types::I64, self.write_literal_to_var as i64); + fn_builder.ins().call_indirect(sig_ref, write_literal_to_var, &[machine_state_value, reg_value, c_value]); + } + Instruction::PutConstant(_, c, reg) => { + let reg_ptr = fn_builder.ins().iconst(types::I64, self.registers as i64); + let reg_num = reg.reg_num(); + let c = unsafe { std::mem::transmute::(u64::from(c)) }; + let c_value = fn_builder.ins().iconst(types::I64, c); + fn_builder.ins().store(MemFlags::new(), c_value, reg_ptr, Offset32::new((reg_num as i32)*8)); + } + _ => { + return Err(JitCompileError::InstructionNotImplemented); + } + } + } + module.define_function(func, &mut ctx).unwrap(); + module.clear_context(&mut ctx); + + module.finalize_definitions().unwrap(); + + let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; + self.modules.insert(name.to_string(), CompileOutput { + module, + code_ptr + }); + + Ok(()) + } + + pub fn exec(&self, name: &str) { + if let Some(output) = self.modules.get(name) { + let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn() -> ()>(output.code_ptr) }; + code_ptr(); + } else { + panic!("Can't find function"); + } + } +} + +// basic. +#[test] +fn jit_test_proceed() { + let machine_st = MachineState::new(); + let code = vec![Instruction::Proceed]; + let name = "basic/0"; + + let mut jit = JitMachine::new(&machine_st); + assert_eq!(jit.compile(name, code), Ok(())); + jit.exec(name); +} + +// basic. +// simple :- basic. +#[test] +fn jit_test_execute_named() { + let machine_st = MachineState::new(); + let mut jit = JitMachine::new(&machine_st); + let code = vec![Instruction::Proceed]; + let name = "basic/0"; + assert_eq!(jit.compile(name, code), Ok(())); + + let code = vec![Instruction::ExecuteNamed(0, atom!("basic"), CodeIndex::default(&mut Arena::new()))]; + let name = "simple/0"; + assert_eq!(jit.compile(name, code), Ok(())); + jit.exec(name); +} + +// a(5). +// b :- a(5). +#[test] +fn jit_test_get_constant() { + let machine_st = MachineState::new(); + let mut jit = JitMachine::new(&machine_st); + let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; + let name = "a/1"; + assert_eq!(jit.compile(name, code), Ok(())); + + let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; + let name = "b/0"; + assert_eq!(jit.compile(name, code), Ok(())); + jit.exec(name); + assert_eq!(machine_st.fail, false); +} + +// a(5). +// b :- a(6). +#[test] +fn jit_test_get_constant_fail() { + let machine_st = MachineState::new(); + let machine_st = Box::pin(MachineState::new()); + let mut jit = JitMachine::new(&machine_st); + let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; + let name = "a/1"; + assert_eq!(jit.compile(name, code), Ok(())); + + let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(6)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; + let name = "b/0"; + assert_eq!(jit.compile(name, code), Ok(())); + jit.exec(name); + assert_eq!(machine_st.fail, true); +} From f8345f8765a798351dd811607d81d13ec2cd3273 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 19 Mar 2024 00:15:42 +0100 Subject: [PATCH 03/28] WIP jit_compile/1 --- Cargo.lock | 56 ++++++++++++------------- Cargo.toml | 6 +-- build/instructions_template.rs | 4 ++ src/lib/jit.pl | 11 +++++ src/machine/compile.rs | 2 - src/machine/dispatch.rs | 8 ++++ src/machine/jit.rs | 15 +++++-- src/machine/mod.rs | 18 +++++++++ src/machine/system_calls.rs | 74 ++++++++++++++++++++++++++++++++++ 9 files changed, 157 insertions(+), 37 deletions(-) create mode 100644 src/lib/jit.pl diff --git a/Cargo.lock b/Cargo.lock index 72bbd9f62..c7eb71793 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.6" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "getrandom", @@ -494,9 +494,9 @@ dependencies = [ [[package]] name = "cranelift" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71ad1c9f3eb965e6f7af65524ca625f2800c7951a213ee7336e808dd42808262" +checksum = "6d8964676f9a0165edbdb7650344acb73d4decd590685f0964b6109b3691fff8" dependencies = [ "cranelift-codegen", "cranelift-frontend", @@ -504,18 +504,18 @@ dependencies = [ [[package]] name = "cranelift-bforest" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ec790dbba3970f5dc1fb615e81167adbe90a81b6d5ce53d1d7f97d1da0c816" +checksum = "16d5521e2abca66bbb1ddeecbb6f6965c79160352ae1579b39f8c86183895c24" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aff1e6625920ec73a0a222bf8f9e5b7fc6a765ccef47e0a010fe327f07eb430a" +checksum = "ef40a4338a47506e832ac3e53f7f1375bc59351f049a8379ff736dd02565bd95" dependencies = [ "bumpalo", "cranelift-bforest", @@ -534,39 +534,39 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1b92f1d526daa2b1878f0171d3a216a70e3f05d2fe6786a5469299c6919a0a" +checksum = "d24cd5d85985c070f73dfca07521d09086362d1590105ba44b0932bf33513b61" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3333fb7671c8b6a8d24fd05dced0a4940773d9cdddeeb38b1f99260eba2c50e9" +checksum = "e0584c4363e3aa0a3c7cb98a778fbd5326a3709f117849a727da081d4051726c" [[package]] name = "cranelift-control" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf07b47de593bc43b7bc25a66c509469fd02c2ad850833c6a2dff45e81980c3b" +checksum = "f25ecede098c6553fdba362a8e4c9ecb8d40138363bff47f9712db75be7f0571" dependencies = [ "arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfd988176a86d6985dd75ac0762033b36d134c4af8b806b605e8542c489fe1c2" +checksum = "6ea081a42f25dc4c5b248b87efdd87dcd3842a1050a37524ec5391e6172058cb" [[package]] name = "cranelift-frontend" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214f216d2288f695bf0e0c2775d4e58759cd6ad5a118852fb5cc26d817e3cdcc" +checksum = "9796e712f5af797e247784f7518e6b0a83a8907d73d51526982d86ecb3a58b68" dependencies = [ "cranelift-codegen", "log", @@ -576,15 +576,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf4dd3bc4bd4d594180a4db16b0b411a9c3a6757a4d13c5ede1177a0131fc80" +checksum = "f4a66ccad5782f15c80e9dd5af0df4acfe6e3eee98e8f7354a2e5c8ec3104bdd" [[package]] name = "cranelift-jit" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52bb8f1189dd81ed989b2cbd62371d67d35afb8977239d4a917a05b7e22e0692" +checksum = "dabad94c317b89a8300e9c48ee4715ffc0c1235ec94cf6bb71aa12511cb1afe8" dependencies = [ "anyhow", "cranelift-codegen", @@ -602,9 +602,9 @@ dependencies = [ [[package]] name = "cranelift-module" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecedd1fec01db9760f94dc2abca90975ef5e90746be5aa0041dbba99268ac23" +checksum = "7d4f782914c709b021e76475516358e64287d17e950113e051468ccf9ca8ff43" dependencies = [ "anyhow", "cranelift-codegen", @@ -613,9 +613,9 @@ dependencies = [ [[package]] name = "cranelift-native" -version = "0.105.0" +version = "0.105.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ddcb8ffe184e9aaf941387f9bcd016236fe1bedc29a84af0771d98443a71aea" +checksum = "285e80df1d9b79ded9775b285df68b920a277b84f88a7228d2f5bc31fcdc58eb" dependencies = [ "cranelift-codegen", "libc", @@ -3792,9 +3792,9 @@ checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasmtime-jit-icache-coherence" -version = "18.0.0" +version = "18.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c87783a70d3b7602834118f42e73ed8979f1b75a01e0fc4bf311cc6dc31f8fc" +checksum = "866634605089b4632b32226b54aa3670d72e1849f9fc425c7e50b3749c2e6df3" dependencies = [ "cfg-if", "libc", diff --git a/Cargo.toml b/Cargo.toml index ccc7b36d2..82bac54fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,9 +83,9 @@ native-tls = { version = "0.2.4", optional = true } warp = { version = "=0.3.5", features = ["tls"], optional = true } reqwest = { version = "0.11.18", optional = true } tokio = { version = "1.28.2", features = ["full"] } -cranelift = "0.105.0" -cranelift-jit = "0.105.0" -cranelift-module = "0.105.0" +cranelift = "0.105.3" +cranelift-jit = "0.105.3" +cranelift-module = "0.105.3" [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.10", features = ["js"] } diff --git a/build/instructions_template.rs b/build/instructions_template.rs index 1840341b1..604ad4f06 100644 --- a/build/instructions_template.rs +++ b/build/instructions_template.rs @@ -606,6 +606,8 @@ enum SystemClauseType { InferenceLimitExceeded, #[strum_discriminants(strum(props(Arity = "1", Name = "$argv")))] Argv, + #[strum_discriminants(strum(props(Arity = "3", Name = "$jit_compile")))] + JitCompile, REPL(REPLCodePtr), } @@ -1902,6 +1904,7 @@ fn generate_instruction_preface() -> TokenStream { &Instruction::CallAddNonCountedBacktracking | &Instruction::CallPopCount | &Instruction::CallArgv | + &Instruction::CallJitCompile | &Instruction::CallEd25519SignRaw | &Instruction::CallEd25519VerifyRaw | &Instruction::CallEd25519SeedToPublicKey => { @@ -2138,6 +2141,7 @@ fn generate_instruction_preface() -> TokenStream { &Instruction::ExecuteAddNonCountedBacktracking | &Instruction::ExecutePopCount | &Instruction::ExecuteArgv | + &Instruction::ExecuteJitCompile | &Instruction::ExecuteEd25519SignRaw | &Instruction::ExecuteEd25519VerifyRaw | &Instruction::ExecuteEd25519SeedToPublicKey => { diff --git a/src/lib/jit.pl b/src/lib/jit.pl new file mode 100644 index 000000000..d04b2766c --- /dev/null +++ b/src/lib/jit.pl @@ -0,0 +1,11 @@ +:- module(jit, [jit_compile/1]). + +jit_compile(Clause) :- + ( nonvar(Clause) -> + ( Clause = Name / Arity -> + '$jit_compile'(user, Name, Arity) + ; Clause = Module : (Name / Arity) -> + '$jit_compile'(Module, Name, Arity) + ) + ; throw(error(instantiation_error, jit_compile/1)) + ). diff --git a/src/machine/compile.rs b/src/machine/compile.rs index 6eaed8aad..7f967e8bd 100644 --- a/src/machine/compile.rs +++ b/src/machine/compile.rs @@ -1355,8 +1355,6 @@ impl<'a, LS: LoadState<'a>> Loader<'a, LS> { index_ptr, ); - dbg!(&code); - self.wam_prelude.code.extend(code); Ok(code_index) } diff --git a/src/machine/dispatch.rs b/src/machine/dispatch.rs index f3a73e6ce..06702b1f1 100644 --- a/src/machine/dispatch.rs +++ b/src/machine/dispatch.rs @@ -4164,6 +4164,14 @@ impl Machine { try_or_throw!(self.machine_st, self.argv()); step_or_fail!(self, self.machine_st.p = self.machine_st.cp); } + &Instruction::CallJitCompile => { + try_or_throw!(self.machine_st, self.jit_compile()); + step_or_fail!(self, self.machine_st.p += 1); + } + &Instruction::ExecuteJitCompile => { + try_or_throw!(self.machine_st, self.jit_compile()); + step_or_fail!(self, self.machine_st.p = self.machine_st.cp); + } &Instruction::CallCurrentTime => { self.current_time(); step_or_fail!(self, self.machine_st.p += 1); diff --git a/src/machine/jit.rs b/src/machine/jit.rs index d69bf1af3..0cb1814bb 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -14,18 +14,24 @@ struct CompileOutput { } #[derive(Debug, PartialEq)] -enum JitCompileError { +pub enum JitCompileError { UndefinedPredicate, InstructionNotImplemented, } -struct JitMachine { +pub struct JitMachine { modules: HashMap, machine_state: *const u8, registers: *const u8, write_literal_to_var: *const u8, } +impl std::fmt::Debug for JitMachine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JitMachine") + } +} + impl JitMachine { pub fn new(machine_st: &MachineState) -> Self { @@ -124,12 +130,13 @@ impl JitMachine { Ok(()) } - pub fn exec(&self, name: &str) { + pub fn exec(&self, name: &str) -> Result<(), ()> { if let Some(output) = self.modules.get(name) { let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn() -> ()>(output.code_ptr) }; code_ptr(); + Ok(()) } else { - panic!("Can't find function"); + Err(()) } } } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index f67315a54..a5454feae 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -40,6 +40,7 @@ use crate::machine::args::*; use crate::machine::compile::*; use crate::machine::copier::*; use crate::machine::heap::*; +use crate::machine::jit::*; use crate::machine::loader::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; @@ -81,6 +82,7 @@ pub struct Machine { #[cfg(feature = "ffi")] pub(super) foreign_function_table: ForeignFunctionTable, pub(super) rng: StdRng, + pub(super) jit_machine: JitMachine, } #[derive(Debug)] @@ -467,6 +469,8 @@ impl Machine { ), }; + let jit_machine = JitMachine::new(&machine_st); + let mut wam = Machine { machine_st, indices: IndexStore::new(), @@ -478,6 +482,7 @@ impl Machine { #[cfg(feature = "ffi")] foreign_function_table: Default::default(), rng: StdRng::from_entropy(), + jit_machine: jit_machine, }; let mut lib_path = current_dir(); @@ -1147,6 +1152,19 @@ impl Machine { fn try_execute(&mut self, name: Atom, arity: usize, idx: IndexPtr) -> CallResult { let compiled_tl_index = idx.p() as usize; + if name == atom!("a") { + self.machine_st.p = self.machine_st.cp; + self.machine_st.oip = 0; + self.machine_st.iip = 0; + self.machine_st.num_of_args = arity; + self.machine_st.b0 = self.machine_st.b; + } + + if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity)) { + println!("Executed JIT predicate"); + return Ok(()); + } + match idx.tag() { IndexPtrTag::DynamicUndefined => { self.machine_st.fail = true; diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 73bf62f22..0645e05fa 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -19,6 +19,7 @@ use crate::machine; use crate::machine::code_walker::*; use crate::machine::copier::*; use crate::machine::heap::*; +use crate::machine::jit::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; use crate::machine::machine_state::*; @@ -4971,6 +4972,79 @@ impl Machine { Ok(()) } + // Tries to compile an existing predicate into native code + // For this to work: the predicate must be loaded, must use the subset of Prolog supported by the JIT + // and every call to a predicate must have been compiled previously + #[inline(always)] + pub(crate) fn jit_compile(&mut self) -> CallResult { + let module_name = cell_as_atom!(self.deref_register(1)); + let name = cell_as_atom!(self.deref_register(2)); + let arity = self.deref_register(3); + + let arity = match Number::try_from(arity) { + Ok(Number::Fixnum(n)) => n.get_num() as usize, + Ok(Number::Integer(n)) => { + let value: usize = (&*n).try_into().unwrap(); + value + } + _ => { + unreachable!() + } + }; + + let key = (name, arity); + + let first_idx = match module_name { + atom!("user") => self.indices.code_dir.get(&key), + _ => match self.indices.modules.get(&module_name) { + Some(module) => module.code_dir.get(&key), + None => { + let stub = functor_stub(key.0, key.1); + let err = self.machine_st.session_error(SessionError::from( + CompilationError::InvalidModuleResolution(module_name), + )); + + return Err(self.machine_st.error_form(err, stub)); + } + }, + }; + + let first_idx = match first_idx { + Some(idx) if idx.local().is_some() => { + if let Some(idx) = idx.local() { + idx + } else { + unreachable!() + } + } + _ => { + let stub = functor_stub(name, arity); + let err = self + .machine_st + .existence_error(ExistenceError::Procedure(name, arity)); + + return Err(self.machine_st.error_form(err, stub)); + } + }; + + let mut code = vec![]; + walk_code(&self.code, first_idx, |instr| code.push(instr.clone())); + + match self.jit_machine.compile(&format!("{}/{}", name.as_str(), arity), code) { + Err(JitCompileError::UndefinedPredicate) => { + eprintln!("jit_compiler: undefined_predicate"); + self.machine_st.fail = true; + } + Err(JitCompileError::InstructionNotImplemented) => { + eprintln!("jit_compiler: instruction not implemented"); + self.machine_st.fail = true; + } + _ => {} + } + + Ok(()) + } + #[inline(always)] pub(crate) fn current_time(&mut self) { let timestamp = self.systemtime_to_timestamp(SystemTime::now()); From 8b2785261a3137e71c42e8f2685cda10eda5d5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Wed, 20 Mar 2024 20:07:36 +0100 Subject: [PATCH 04/28] First predicate able to execute! --- src/machine/jit.rs | 41 ++++++++++++++++++++++--------- src/machine/machine_state.rs | 1 + src/machine/machine_state_impl.rs | 5 ++-- src/machine/mod.rs | 13 +++------- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 0cb1814bb..88533e87a 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -21,9 +21,8 @@ pub enum JitCompileError { pub struct JitMachine { modules: HashMap, - machine_state: *const u8, - registers: *const u8, write_literal_to_var: *const u8, + deref: *const u8, } impl std::fmt::Debug for JitMachine { @@ -37,9 +36,8 @@ impl JitMachine { JitMachine { modules: HashMap::new(), - machine_state: machine_st as *const MachineState as *const u8, - registers: machine_st.registers.as_ptr() as *const u8, write_literal_to_var: MachineState::write_literal_to_var as *const u8, + deref: MachineState::deref as *const u8, } } @@ -53,18 +51,21 @@ impl JitMachine { builder.symbols(self.modules.iter().map(|(k, v)| (k,v.code_ptr))); let mut module = JITModule::new(builder); + let pointer_type = module.isa().pointer_type(); let mut ctx = module.make_context(); let mut func_ctx = FunctionBuilderContext::new(); let mut sig = module.make_signature(); - // Set arguments/returns - sig.call_conv = isa::CallConv::Tail; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); + sig.call_conv = isa::CallConv::SystemV; ctx.func.signature = sig.clone(); let mut func = module.declare_function(name, Linkage::Local, &sig).unwrap(); let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); let block = fn_builder.create_block(); + fn_builder.append_block_params_for_function_params(block); fn_builder.switch_to_block(block); for wam_instr in code { match wam_instr { @@ -89,12 +90,21 @@ impl JitMachine { } } Instruction::GetConstant(_, c, reg) => { - let reg_ptr = fn_builder.ins().iconst(types::I64, self.registers as i64); // TODO: call deref + let machine_state_value = fn_builder.block_params(block)[0]; + let reg_ptr = fn_builder.block_params(block)[1]; let reg_num = reg.reg_num(); let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((reg_num as i32)*8)); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); + let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state_value, reg_value]); + let reg_value = fn_builder.inst_results(deref_call)[0]; let c = unsafe { std::mem::transmute::(u64::from(c)) }; let c_value = fn_builder.ins().iconst(types::I64, c); - let machine_state_value = fn_builder.ins().iconst(types::I64, self.machine_state as i64); let mut sig = module.make_signature(); sig.call_conv = isa::CallConv::SystemV; sig.params.push(AbiParam::new(types::I64)); @@ -105,7 +115,7 @@ impl JitMachine { fn_builder.ins().call_indirect(sig_ref, write_literal_to_var, &[machine_state_value, reg_value, c_value]); } Instruction::PutConstant(_, c, reg) => { - let reg_ptr = fn_builder.ins().iconst(types::I64, self.registers as i64); + let reg_ptr = fn_builder.block_params(block)[1]; let reg_num = reg.reg_num(); let c = unsafe { std::mem::transmute::(u64::from(c)) }; let c_value = fn_builder.ins().iconst(types::I64, c); @@ -130,10 +140,17 @@ impl JitMachine { Ok(()) } - pub fn exec(&self, name: &str) -> Result<(), ()> { + pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { if let Some(output) = self.modules.get(name) { - let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn() -> ()>(output.code_ptr) }; - code_ptr(); + machine_st.p = machine_st.cp; + machine_st.oip = 0; + machine_st.iip = 0; + // machine_st.num_of_args = arity; + machine_st.num_of_args = 1; + machine_st.b0 = machine_st.b; + + let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn(*mut MachineState, *mut Registers) -> ()>(output.code_ptr) }; + code_ptr(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers); Ok(()) } else { Err(()) diff --git a/src/machine/machine_state.rs b/src/machine/machine_state.rs index 3763b2c09..b2ae875d4 100644 --- a/src/machine/machine_state.rs +++ b/src/machine/machine_state.rs @@ -21,6 +21,7 @@ use indexmap::IndexMap; use std::convert::TryFrom; use std::fmt; +use std::pin::*; use std::ops::{Index, IndexMut}; use std::sync::Arc; diff --git a/src/machine/machine_state_impl.rs b/src/machine/machine_state_impl.rs index 459db1d76..c7f478dd8 100644 --- a/src/machine/machine_state_impl.rs +++ b/src/machine/machine_state_impl.rs @@ -19,6 +19,7 @@ use indexmap::IndexSet; use std::cmp::Ordering; use std::convert::TryFrom; +use std::pin::*; impl MachineState { pub(crate) fn new() -> Self { @@ -79,7 +80,7 @@ impl MachineState { } #[inline] - pub fn deref(&self, mut addr: HeapCellValue) -> HeapCellValue { + pub extern "C" fn deref(&self, mut addr: HeapCellValue) -> HeapCellValue { loop { let value = self.store(addr); @@ -1012,7 +1013,7 @@ impl MachineState { } } - pub(super) fn write_literal_to_var(&mut self, deref_v: HeapCellValue, lit: HeapCellValue) { + pub(super) extern "C" fn write_literal_to_var(&mut self, deref_v: HeapCellValue, lit: HeapCellValue) { let store_v = self.store(deref_v); read_heap_cell!(lit, diff --git a/src/machine/mod.rs b/src/machine/mod.rs index a5454feae..b3510334e 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -61,6 +61,7 @@ use std::cmp::Ordering; use std::env; use std::io::Read; use std::path::PathBuf; +use std::pin::*; use std::sync::atomic::AtomicBool; use self::config::MachineConfig; @@ -1152,16 +1153,8 @@ impl Machine { fn try_execute(&mut self, name: Atom, arity: usize, idx: IndexPtr) -> CallResult { let compiled_tl_index = idx.p() as usize; - if name == atom!("a") { - self.machine_st.p = self.machine_st.cp; - self.machine_st.oip = 0; - self.machine_st.iip = 0; - self.machine_st.num_of_args = arity; - self.machine_st.b0 = self.machine_st.b; - } - - if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity)) { - println!("Executed JIT predicate"); + if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity), &mut self.machine_st) { + println!("jit_compiler: executed JIT predicate"); return Ok(()); } From 930629693c2ab336758762805d0b538155564cab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sat, 23 Mar 2024 23:09:50 +0100 Subject: [PATCH 05/28] Adding trampoline --- src/machine/jit.rs | 76 +++++++++++++++++++++++++++++++++++----------- src/machine/mod.rs | 2 +- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 88533e87a..809ea6802 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -21,6 +21,7 @@ pub enum JitCompileError { pub struct JitMachine { modules: HashMap, + trampoline: extern "C" fn (*mut MachineState, *mut Registers, *const u8), write_literal_to_var: *const u8, deref: *const u8, } @@ -32,10 +33,51 @@ impl std::fmt::Debug for JitMachine { } impl JitMachine { - pub fn new(machine_st: &MachineState) -> Self { + pub fn new() -> Self { + // Build trampoline: from SysV to Tail + let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true")], cranelift_module::default_libcall_names()).unwrap(); + + let mut module = JITModule::new(builder); + let pointer_type = module.isa().pointer_type(); + let mut ctx = module.make_context(); + let mut func_ctx = FunctionBuilderContext::new(); + + let mut sig = module.make_signature(); + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); + sig.call_conv = isa::CallConv::SystemV; + ctx.func.signature = sig.clone(); + + let mut func = module.declare_function("$trampoline", Linkage::Local, &sig).unwrap(); + let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); + let block = fn_builder.create_block(); + fn_builder.append_block_params_for_function_params(block); + fn_builder.switch_to_block(block); + let machine_state = fn_builder.block_params(block)[0]; + let machine_registers = fn_builder.block_params(block)[1]; + let func_addr = fn_builder.block_params(block)[2]; + + let mut sig = module.make_signature(); + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); + sig.call_conv = isa::CallConv::Tail; + let sig_ref = fn_builder.import_signature(sig); + fn_builder.ins().call_indirect(sig_ref, func_addr, &[machine_state, machine_registers]); + fn_builder.ins().return_(&[]); + fn_builder.seal_all_blocks(); + fn_builder.finalize(); + + module.define_function(func, &mut ctx).unwrap(); + module.clear_context(&mut ctx); + + module.finalize_definitions().unwrap(); + + let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; JitMachine { modules: HashMap::new(), + trampoline: code_ptr, write_literal_to_var: MachineState::write_literal_to_var as *const u8, deref: MachineState::deref as *const u8, } @@ -58,7 +100,7 @@ impl JitMachine { let mut sig = module.make_signature(); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = isa::CallConv::Tail; ctx.func.signature = sig.clone(); let mut func = module.declare_function(name, Linkage::Local, &sig).unwrap(); @@ -149,8 +191,9 @@ impl JitMachine { machine_st.num_of_args = 1; machine_st.b0 = machine_st.b; - let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn(*mut MachineState, *mut Registers) -> ()>(output.code_ptr) }; - code_ptr(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers); + //let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn(*mut MachineState, *mut Registers) -> ()>(output.code_ptr) }; + //code_ptr(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers); + (self.trampoline)(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers, output.code_ptr); Ok(()) } else { Err(()) @@ -161,21 +204,21 @@ impl JitMachine { // basic. #[test] fn jit_test_proceed() { - let machine_st = MachineState::new(); + let mut machine_st = MachineState::new(); let code = vec![Instruction::Proceed]; let name = "basic/0"; - let mut jit = JitMachine::new(&machine_st); + let mut jit = JitMachine::new(); assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name); + jit.exec(name, &mut machine_st); } // basic. // simple :- basic. #[test] fn jit_test_execute_named() { - let machine_st = MachineState::new(); - let mut jit = JitMachine::new(&machine_st); + let mut machine_st = MachineState::new(); + let mut jit = JitMachine::new(); let code = vec![Instruction::Proceed]; let name = "basic/0"; assert_eq!(jit.compile(name, code), Ok(())); @@ -183,15 +226,15 @@ fn jit_test_execute_named() { let code = vec![Instruction::ExecuteNamed(0, atom!("basic"), CodeIndex::default(&mut Arena::new()))]; let name = "simple/0"; assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name); + jit.exec(name, &mut machine_st); } // a(5). // b :- a(5). #[test] fn jit_test_get_constant() { - let machine_st = MachineState::new(); - let mut jit = JitMachine::new(&machine_st); + let mut machine_st = MachineState::new(); + let mut jit = JitMachine::new(); let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; let name = "a/1"; assert_eq!(jit.compile(name, code), Ok(())); @@ -199,7 +242,7 @@ fn jit_test_get_constant() { let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; let name = "b/0"; assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name); + jit.exec(name, &mut machine_st); assert_eq!(machine_st.fail, false); } @@ -207,9 +250,8 @@ fn jit_test_get_constant() { // b :- a(6). #[test] fn jit_test_get_constant_fail() { - let machine_st = MachineState::new(); - let machine_st = Box::pin(MachineState::new()); - let mut jit = JitMachine::new(&machine_st); + let mut machine_st = MachineState::new(); + let mut jit = JitMachine::new(); let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; let name = "a/1"; assert_eq!(jit.compile(name, code), Ok(())); @@ -217,6 +259,6 @@ fn jit_test_get_constant_fail() { let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(6)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; let name = "b/0"; assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name); + jit.exec(name, &mut machine_st); assert_eq!(machine_st.fail, true); } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index b3510334e..919613fbe 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -470,7 +470,7 @@ impl Machine { ), }; - let jit_machine = JitMachine::new(&machine_st); + let jit_machine = JitMachine::new(); let mut wam = Machine { machine_st, From 4cb231456a67a41d51c3bb587c6262cd0ee992e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sun, 24 Mar 2024 12:01:59 +0100 Subject: [PATCH 06/28] Fix ExecuteNamed --- src/machine/jit.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 809ea6802..762fe20f5 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -107,7 +107,7 @@ impl JitMachine { let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); let block = fn_builder.create_block(); - fn_builder.append_block_params_for_function_params(block); + fn_builder.append_block_params_for_function_params(block); fn_builder.switch_to_block(block); for wam_instr in code { match wam_instr { @@ -118,12 +118,15 @@ impl JitMachine { break; } Instruction::ExecuteNamed(arity, pred_name, ..) => { + let machine_state_value = fn_builder.block_params(block)[0]; + let reg_ptr = fn_builder.block_params(block)[1]; let mut callee_func_sig = module.make_signature(); callee_func_sig.call_conv = isa::CallConv::Tail; - // right now, all predicates have 0 arguments. In the future with shadow registers, we could improve this + callee_func_sig.params.push(AbiParam::new(pointer_type)); + callee_func_sig.params.push(AbiParam::new(pointer_type)); if let Ok(callee_func) = module.declare_function(&format!("{}/{}", pred_name.as_str(), arity), Linkage::Import, &callee_func_sig) { let func_ref = module.declare_func_in_func(callee_func, fn_builder.func); - fn_builder.ins().return_call(func_ref, &[]); + fn_builder.ins().return_call(func_ref, &[machine_state_value, reg_ptr]); fn_builder.seal_all_blocks(); fn_builder.finalize(); break; From cb80744c83da63b2ef649acb086e88801a7fa0a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 26 Mar 2024 23:19:33 +0100 Subject: [PATCH 07/28] More instructions --- src/machine/jit.rs | 120 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 762fe20f5..6f9b22135 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -22,8 +22,11 @@ pub enum JitCompileError { pub struct JitMachine { modules: HashMap, trampoline: extern "C" fn (*mut MachineState, *mut Registers, *const u8), + offset_interms: usize, write_literal_to_var: *const u8, deref: *const u8, + store: *const u8, + unify_fixnum: *const u8, } impl std::fmt::Debug for JitMachine { @@ -75,16 +78,26 @@ impl JitMachine { let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; + let machine_st = std::mem::MaybeUninit::uninit(); + let machine_st_ptr: *const MachineState = machine_st.as_ptr(); + let machine_st_ptr_u8 = machine_st_ptr as *const u8; + let offset_interms = unsafe { + let interms_ptr = std::ptr::addr_of!((*machine_st_ptr).interms) as *const u8; + interms_ptr.offset_from(machine_st_ptr_u8) as usize + }; + JitMachine { modules: HashMap::new(), trampoline: code_ptr, + offset_interms: offset_interms, write_literal_to_var: MachineState::write_literal_to_var as *const u8, deref: MachineState::deref as *const u8, + store: MachineState::store as *const u8, + unify_fixnum: MachineState::unify_fixnum as *const u8, } } // For now, one module = one predicate - // Functions must take N parameters (arity) // Access to MachineState via global pointer // MachineState Registers + ShadowRegisters?? // Use TAIL call convention @@ -109,6 +122,7 @@ impl JitMachine { let block = fn_builder.create_block(); fn_builder.append_block_params_for_function_params(block); fn_builder.switch_to_block(block); + // TODO: Manage failure for wam_instr in code { match wam_instr { Instruction::Proceed => { @@ -134,6 +148,7 @@ impl JitMachine { return Err(JitCompileError::UndefinedPredicate); } } + // TODO Manage RegType Instruction::GetConstant(_, c, reg) => { let machine_state_value = fn_builder.block_params(block)[0]; let reg_ptr = fn_builder.block_params(block)[1]; @@ -159,6 +174,17 @@ impl JitMachine { let write_literal_to_var = fn_builder.ins().iconst(types::I64, self.write_literal_to_var as i64); fn_builder.ins().call_indirect(sig_ref, write_literal_to_var, &[machine_state_value, reg_value, c_value]); } + Instruction::GetVariable(norm, arg) => { + let reg_ptr = fn_builder.block_params(block)[1]; + let value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((arg as i32)*8)); + match norm { + RegType::Temp(temp) => { + fn_builder.ins().store(MemFlags::new(), value, reg_ptr, Offset32::new((temp as i32)*8)); + } + _ => unimplemented!() + } + } + // TODO Manage RegType Instruction::PutConstant(_, c, reg) => { let reg_ptr = fn_builder.block_params(block)[1]; let reg_num = reg.reg_num(); @@ -166,6 +192,98 @@ impl JitMachine { let c_value = fn_builder.ins().iconst(types::I64, c); fn_builder.ins().store(MemFlags::new(), c_value, reg_ptr, Offset32::new((reg_num as i32)*8)); } + Instruction::PutValue(norm, arg) => { + let reg_ptr = fn_builder.block_params(block)[1]; + match norm { + RegType::Temp(temp) => { + let value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)); + fn_builder.ins().store(MemFlags::new(), value, reg_ptr, Offset32::new((arg as i32)*8)); + } + _ => unimplemented!() + } + } + // TODO Fill more cases. Can we optimize the add in some cases to use the Cranelift add? + Instruction::Add(ref a1, ref a2, t) => { + /*let val_a = match a1 { + ArithmeticTerm::Number(n) => { + n + } + _ => unimplemented!() + }; + let val_b = match a2 { + ArithmeticTerm::Number(n) => { + n + } + _ => unimplemented!() + }; + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let add = fn_builder.ins().iconst(types::I64, self.add as i64); + let add_call = fn_builder.ins().call_indirect(sig_ref, add + + */ + } + // TOO SIMPLE, just for debugging + Instruction::ExecuteIs(r, at) => { + let machine_state = fn_builder.block_params(block)[0]; + let reg_ptr = fn_builder.block_params(block)[1]; + let n1 = match r { + RegType::Temp(temp) => { + fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)) + } + _ => unimplemented!() + }; + + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); + let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state, n1]); + let n1 = fn_builder.inst_results(deref_call)[0]; + + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let store = fn_builder.ins().iconst(types::I64, self.store as i64); + let store_call = fn_builder.ins().call_indirect(sig_ref, store, &[machine_state, n1]); + let n1 = fn_builder.inst_results(store_call)[0]; + + let n = match at { + ArithmeticTerm::Number(n) => { + match n { + Number::Fixnum(fixnum) => { + fixnum + } + _ => unimplemented!() + } + } + _ => unimplemented!() + }; + + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let n = fn_builder.ins().iconst(types::I64, unsafe { std::mem::transmute::<_, i64>(n) }); + let unify_fixnum = fn_builder.ins().iconst(types::I64, self.unify_fixnum as i64); + fn_builder.ins().call_indirect(sig_ref, unify_fixnum, &[machine_state, n, n1]); + fn_builder.ins().return_(&[]); + fn_builder.seal_all_blocks(); + fn_builder.finalize(); + break; + } _ => { return Err(JitCompileError::InstructionNotImplemented); } From a5754d4e5d823e53e453a053f99780ffd174afd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 8 Apr 2024 23:43:42 +0200 Subject: [PATCH 08/28] Add instruction --- src/machine/arithmetic_ops.rs | 5 ++ src/machine/dispatch.rs | 13 ++++ src/machine/jit.rs | 117 ++++++++++++++++++++++++++-------- 3 files changed, 109 insertions(+), 26 deletions(-) diff --git a/src/machine/arithmetic_ops.rs b/src/machine/arithmetic_ops.rs index 3de47a604..959bb67e3 100644 --- a/src/machine/arithmetic_ops.rs +++ b/src/machine/arithmetic_ops.rs @@ -128,6 +128,11 @@ fn isize_gcd(n1: isize, n2: isize) -> Option { Some(n1 << shift as isize) } +// TODO: Error handling +pub extern "C" fn add_jit(lhs: Number, rhs: Number, arena: &mut Arena) -> Number { + add(lhs, rhs, &mut Arena::new()).unwrap() +} + pub(crate) fn add(lhs: Number, rhs: Number, arena: &mut Arena) -> Result { match (lhs, rhs) { (Number::Fixnum(n1), Number::Fixnum(n2)) => Ok( diff --git a/src/machine/dispatch.rs b/src/machine/dispatch.rs index 06702b1f1..183b442af 100644 --- a/src/machine/dispatch.rs +++ b/src/machine/dispatch.rs @@ -187,6 +187,19 @@ impl MachineState { Ok(()) } + pub extern "C" fn unify_num_jit(&mut self, n: Number, n1: HeapCellValue) { + match n { + Number::Fixnum(n) => self.unify_fixnum(n, n1), + Number::Float(n) => { + let n = float_alloc!(n.into_inner(), self.arena); + self.unify_f64(n, n1) + } + Number::Integer(n) => self.unify_big_int(n, n1), + Number::Rational(n) => self.unify_rational(n, n1), + + } + } + #[inline(always)] pub(crate) fn select_switch_on_term_index( &self, diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 6f9b22135..3e4feea31 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use crate::instructions::*; use crate::machine::*; +use crate::machine::arithmetic_ops::add_jit; use cranelift::prelude::*; use cranelift::prelude::codegen::ir::immediates::Offset32; @@ -23,10 +24,14 @@ pub struct JitMachine { modules: HashMap, trampoline: extern "C" fn (*mut MachineState, *mut Registers, *const u8), offset_interms: usize, + offset_arena: usize, write_literal_to_var: *const u8, deref: *const u8, store: *const u8, - unify_fixnum: *const u8, + unify_num: *const u8, + get_number: *const u8, + add: *const u8, + vec_as_ptr: *const u8, } impl std::fmt::Debug for JitMachine { @@ -38,7 +43,7 @@ impl std::fmt::Debug for JitMachine { impl JitMachine { pub fn new() -> Self { // Build trampoline: from SysV to Tail - let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true")], cranelift_module::default_libcall_names()).unwrap(); + let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true"), ("enable_llvm_abi_extensions", "1")], cranelift_module::default_libcall_names()).unwrap(); let mut module = JITModule::new(builder); let pointer_type = module.isa().pointer_type(); @@ -85,15 +90,23 @@ impl JitMachine { let interms_ptr = std::ptr::addr_of!((*machine_st_ptr).interms) as *const u8; interms_ptr.offset_from(machine_st_ptr_u8) as usize }; + let offset_arena = unsafe { + let arena_ptr = std::ptr::addr_of!((*machine_st_ptr).arena) as *const u8; + arena_ptr.offset_from(machine_st_ptr_u8) as usize + }; JitMachine { modules: HashMap::new(), trampoline: code_ptr, offset_interms: offset_interms, + offset_arena: offset_arena, write_literal_to_var: MachineState::write_literal_to_var as *const u8, deref: MachineState::deref as *const u8, store: MachineState::store as *const u8, - unify_fixnum: MachineState::unify_fixnum as *const u8, + unify_num: MachineState::unify_num_jit as *const u8, + get_number: MachineState::get_number as *const u8, + add: add_jit as *const u8, + vec_as_ptr: Vec::::as_ptr as *const u8, } } @@ -102,7 +115,7 @@ impl JitMachine { // MachineState Registers + ShadowRegisters?? // Use TAIL call convention pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError>{ - let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true")], cranelift_module::default_libcall_names()).unwrap(); + let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true"), ("enable_llvm_abi_extensions", "1")], cranelift_module::default_libcall_names()).unwrap(); builder.symbols(self.modules.iter().map(|(k, v)| (k,v.code_ptr))); let mut module = JITModule::new(builder); @@ -204,28 +217,71 @@ impl JitMachine { } // TODO Fill more cases. Can we optimize the add in some cases to use the Cranelift add? Instruction::Add(ref a1, ref a2, t) => { - /*let val_a = match a1 { - ArithmeticTerm::Number(n) => { - n + let machine_state = fn_builder.block_params(block)[0]; + let n1 = match a1 { + &ArithmeticTerm::Number(n) => { + let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; + let lo = fn_builder.ins().iconst(types::I64, n128 as i64); + let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); + fn_builder.ins().iconcat(lo, hi) } + &ArithmeticTerm::Interm(i) => { + let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); + let interms = fn_builder.inst_results(vec_ptr_call)[0]; + fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) + } _ => unimplemented!() }; - let val_b = match a2 { - ArithmeticTerm::Number(n) => { - n + let n2 = match a2 { + &ArithmeticTerm::Number(n) => { + let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; + let lo = fn_builder.ins().iconst(types::I64, n128 as i64); + let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); + fn_builder.ins().iconcat(lo, hi) } + &ArithmeticTerm::Interm(i) => { + let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); + let interms = fn_builder.inst_results(vec_ptr_call)[0]; + fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) + } _ => unimplemented!() }; let mut sig = module.make_signature(); sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(types::I64)); - sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I128)); + sig.params.push(AbiParam::new(types::I128)); sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(types::I128)); let sig_ref = fn_builder.import_signature(sig); - let add = fn_builder.ins().iconst(types::I64, self.add as i64); - let add_call = fn_builder.ins().call_indirect(sig_ref, add - - */ + let arena = fn_builder.ins().iadd_imm(machine_state, self.offset_arena as i64); + let add_jit = fn_builder.ins().iconst(types::I64, self.add as i64); + let add_jit_call = fn_builder.ins().call_indirect(sig_ref, add_jit, &[n1, n2, arena]); + let n3 = fn_builder.inst_results(add_jit_call)[0]; + + let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); + let interms = fn_builder.inst_results(vec_ptr_call)[0]; + fn_builder.ins().store(MemFlags::new(), n3, interms, Offset32::new((t as i32 - 1) * 16)); } // TOO SIMPLE, just for debugging Instruction::ExecuteIs(r, at) => { @@ -260,12 +316,22 @@ impl JitMachine { let n = match at { ArithmeticTerm::Number(n) => { - match n { - Number::Fixnum(fixnum) => { - fixnum - } - _ => unimplemented!() - } + let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; + let lo = fn_builder.ins().iconst(types::I64, n128 as i64); + let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); + fn_builder.ins().iconcat(lo, hi) + } + ArithmeticTerm::Interm(i) => { + let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); + let interms = fn_builder.inst_results(vec_ptr_call)[0]; + fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) } _ => unimplemented!() }; @@ -273,12 +339,11 @@ impl JitMachine { let mut sig = module.make_signature(); sig.call_conv = isa::CallConv::SystemV; sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I128)); sig.params.push(AbiParam::new(types::I64)); let sig_ref = fn_builder.import_signature(sig); - let n = fn_builder.ins().iconst(types::I64, unsafe { std::mem::transmute::<_, i64>(n) }); - let unify_fixnum = fn_builder.ins().iconst(types::I64, self.unify_fixnum as i64); - fn_builder.ins().call_indirect(sig_ref, unify_fixnum, &[machine_state, n, n1]); + let unify_num = fn_builder.ins().iconst(types::I64, self.unify_num as i64); + fn_builder.ins().call_indirect(sig_ref, unify_num, &[machine_state, n, n1]); fn_builder.ins().return_(&[]); fn_builder.seal_all_blocks(); fn_builder.finalize(); From b668c26f0765c03bb288d76d62b6c21799329a87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 15 Apr 2024 08:06:21 +0200 Subject: [PATCH 09/28] Add tests nd jit_store_deref_reg helper --- src/machine/jit.rs | 67 +++++++++++++++------------ src/tests/jit_test.pl | 10 ++++ tests/scryer/cli/src_tests/jit.stderr | 0 tests/scryer/cli/src_tests/jit.stdout | 1 + tests/scryer/cli/src_tests/jit.toml | 1 + 5 files changed, 50 insertions(+), 29 deletions(-) create mode 100644 src/tests/jit_test.pl create mode 100644 tests/scryer/cli/src_tests/jit.stderr create mode 100644 tests/scryer/cli/src_tests/jit.stdout create mode 100644 tests/scryer/cli/src_tests/jit.toml diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 3e4feea31..fed3f4264 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -6,6 +6,7 @@ use crate::machine::arithmetic_ops::add_jit; use cranelift::prelude::*; use cranelift::prelude::codegen::ir::immediates::Offset32; +use cranelift::prelude::codegen::ir::entities::Value; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module}; @@ -187,6 +188,7 @@ impl JitMachine { let write_literal_to_var = fn_builder.ins().iconst(types::I64, self.write_literal_to_var as i64); fn_builder.ins().call_indirect(sig_ref, write_literal_to_var, &[machine_state_value, reg_value, c_value]); } + // TODO: Manage RegType Instruction::GetVariable(norm, arg) => { let reg_ptr = fn_builder.block_params(block)[1]; let value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((arg as i32)*8)); @@ -287,33 +289,7 @@ impl JitMachine { Instruction::ExecuteIs(r, at) => { let machine_state = fn_builder.block_params(block)[0]; let reg_ptr = fn_builder.block_params(block)[1]; - let n1 = match r { - RegType::Temp(temp) => { - fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)) - } - _ => unimplemented!() - }; - - let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); - let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state, n1]); - let n1 = fn_builder.inst_results(deref_call)[0]; - - let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let store = fn_builder.ins().iconst(types::I64, self.store as i64); - let store_call = fn_builder.ins().call_indirect(sig_ref, store, &[machine_state, n1]); - let n1 = fn_builder.inst_results(store_call)[0]; - + let n1 = self.jit_store_deref_reg(&module, machine_state, reg_ptr, &mut fn_builder, r); let n = match at { ArithmeticTerm::Number(n) => { let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; @@ -333,6 +309,9 @@ impl JitMachine { let interms = fn_builder.inst_results(vec_ptr_call)[0]; fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) } + /*ArithmeticTerm::Reg(reg_type) => { + // TODO + }*/ _ => unimplemented!() }; @@ -367,6 +346,38 @@ impl JitMachine { Ok(()) } + + fn jit_store_deref_reg(&self, module: &JITModule, machine_state: Value, reg_ptr: Value, fn_builder: &mut FunctionBuilder, reg: RegType) -> Value { + let pointer_type = module.isa().pointer_type(); + let system_call_conv = module.isa().default_call_conv(); + let n1 = match reg { + RegType::Temp(temp) => { + fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)) + } + _ => unimplemented!() // TODO + }; + + let mut sig = module.make_signature(); + sig.call_conv = system_call_conv; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); + let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state, n1]); + let n1 = fn_builder.inst_results(deref_call)[0]; + + let mut sig = module.make_signature(); + sig.call_conv = system_call_conv; + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + let sig_ref = fn_builder.import_signature(sig); + let store = fn_builder.ins().iconst(types::I64, self.store as i64); + let store_call = fn_builder.ins().call_indirect(sig_ref, store, &[machine_state, n1]); + fn_builder.inst_results(store_call)[0] + } + pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { if let Some(output) = self.modules.get(name) { @@ -377,8 +388,6 @@ impl JitMachine { machine_st.num_of_args = 1; machine_st.b0 = machine_st.b; - //let code_ptr = unsafe { std::mem::transmute::<_, extern "C" fn(*mut MachineState, *mut Registers) -> ()>(output.code_ptr) }; - //code_ptr(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers); (self.trampoline)(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers, output.code_ptr); Ok(()) } else { diff --git a/src/tests/jit_test.pl b/src/tests/jit_test.pl new file mode 100644 index 000000000..2e7acab51 --- /dev/null +++ b/src/tests/jit_test.pl @@ -0,0 +1,10 @@ +:- use_module(library(jit)). + +a(X) :- b(X). +b(X) :- X is 1 + 1 + 1 + 1. + +test :- + jit_compile(b/1), + jit_compile(a/1), + a(X), + X = 4. diff --git a/tests/scryer/cli/src_tests/jit.stderr b/tests/scryer/cli/src_tests/jit.stderr new file mode 100644 index 000000000..e69de29bb diff --git a/tests/scryer/cli/src_tests/jit.stdout b/tests/scryer/cli/src_tests/jit.stdout new file mode 100644 index 000000000..332f8828f --- /dev/null +++ b/tests/scryer/cli/src_tests/jit.stdout @@ -0,0 +1 @@ +jit_compiler: executed JIT predicate diff --git a/tests/scryer/cli/src_tests/jit.toml b/tests/scryer/cli/src_tests/jit.toml new file mode 100644 index 000000000..d053abe9a --- /dev/null +++ b/tests/scryer/cli/src_tests/jit.toml @@ -0,0 +1 @@ +args = ["-f", "--no-add-history", "src/tests/jit_test.pl", "-f", "-g", "test"] From 9dcbd23b10ed19299fc377009c2015948ea730ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 15 Apr 2024 23:08:05 +0200 Subject: [PATCH 10/28] Improve add and execute to handle registers --- src/arithmetic.rs | 6 ++ src/machine/jit.rs | 114 ++++++++++---------------- src/tests/jit_test.pl | 5 +- tests/scryer/cli/src_tests/jit.stdout | 1 + 4 files changed, 56 insertions(+), 70 deletions(-) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 799dfbd65..19910273e 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -674,6 +674,12 @@ impl Ord for Number { } } +impl Number { + pub extern "C" fn jit_try_from(value: HeapCellValue) -> Number { + Number::try_from(value).unwrap() + } +} + impl TryFrom for Number { type Error = (); diff --git a/src/machine/jit.rs b/src/machine/jit.rs index fed3f4264..6e29a77d1 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -33,6 +33,7 @@ pub struct JitMachine { get_number: *const u8, add: *const u8, vec_as_ptr: *const u8, + number_try_from: *const u8, } impl std::fmt::Debug for JitMachine { @@ -108,6 +109,7 @@ impl JitMachine { get_number: MachineState::get_number as *const u8, add: add_jit as *const u8, vec_as_ptr: Vec::::as_ptr as *const u8, + number_try_from: Number::jit_try_from as *const u8, } } @@ -218,50 +220,11 @@ impl JitMachine { } } // TODO Fill more cases. Can we optimize the add in some cases to use the Cranelift add? - Instruction::Add(ref a1, ref a2, t) => { - let machine_state = fn_builder.block_params(block)[0]; - let n1 = match a1 { - &ArithmeticTerm::Number(n) => { - let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; - let lo = fn_builder.ins().iconst(types::I64, n128 as i64); - let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); - fn_builder.ins().iconcat(lo, hi) - } - &ArithmeticTerm::Interm(i) => { - let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); - let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); - let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); - let interms = fn_builder.inst_results(vec_ptr_call)[0]; - fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) - } - _ => unimplemented!() - }; - let n2 = match a2 { - &ArithmeticTerm::Number(n) => { - let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; - let lo = fn_builder.ins().iconst(types::I64, n128 as i64); - let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); - fn_builder.ins().iconcat(lo, hi) - } - &ArithmeticTerm::Interm(i) => { - let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); - let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); - let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); - let interms = fn_builder.inst_results(vec_ptr_call)[0]; - fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) - } - _ => unimplemented!() - }; + Instruction::Add(a1, a2, t) => { + let machine_state = fn_builder.block_params(block)[0]; + let reg_ptr = fn_builder.block_params(block)[1]; + let n1 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a1); + let n2 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a2); let mut sig = module.make_signature(); sig.call_conv = isa::CallConv::SystemV; sig.params.push(AbiParam::new(types::I128)); @@ -285,35 +248,11 @@ impl JitMachine { let interms = fn_builder.inst_results(vec_ptr_call)[0]; fn_builder.ins().store(MemFlags::new(), n3, interms, Offset32::new((t as i32 - 1) * 16)); } - // TOO SIMPLE, just for debugging Instruction::ExecuteIs(r, at) => { let machine_state = fn_builder.block_params(block)[0]; let reg_ptr = fn_builder.block_params(block)[1]; let n1 = self.jit_store_deref_reg(&module, machine_state, reg_ptr, &mut fn_builder, r); - let n = match at { - ArithmeticTerm::Number(n) => { - let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; - let lo = fn_builder.ins().iconst(types::I64, n128 as i64); - let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); - fn_builder.ins().iconcat(lo, hi) - } - ArithmeticTerm::Interm(i) => { - let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); - let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); - let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); - let interms = fn_builder.inst_results(vec_ptr_call)[0]; - fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) - } - /*ArithmeticTerm::Reg(reg_type) => { - // TODO - }*/ - _ => unimplemented!() - }; + let n = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, at); let mut sig = module.make_signature(); sig.call_conv = isa::CallConv::SystemV; @@ -378,6 +317,43 @@ impl JitMachine { fn_builder.inst_results(store_call)[0] } + fn jit_get_number(&self, module: &JITModule, machine_state: Value, reg_ptr: Value, fn_builder: &mut FunctionBuilder, at: ArithmeticTerm) -> Value { + let pointer_type = module.isa().pointer_type(); + let system_call_conv = module.isa().default_call_conv(); + + match at { + ArithmeticTerm::Number(n) => { + let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; + let lo = fn_builder.ins().iconst(types::I64, n128 as i64); + let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); + fn_builder.ins().iconcat(lo, hi) + } + ArithmeticTerm::Interm(i) => { + let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); + let interms = fn_builder.inst_results(vec_ptr_call)[0]; + fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) + } + ArithmeticTerm::Reg(reg_type) => { + let value = self.jit_store_deref_reg(&module, machine_state, reg_ptr, fn_builder, reg_type); + let mut sig = module.make_signature(); + sig.call_conv = isa::CallConv::SystemV; + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I128)); + let sig_ref = fn_builder.import_signature(sig); + let number_try_from = fn_builder.ins().iconst(types::I64, self.number_try_from as i64); + let number_try_from_call = fn_builder.ins().call_indirect(sig_ref, number_try_from, &[value]); + fn_builder.inst_results(number_try_from_call)[0] + } + } + } + pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { if let Some(output) = self.modules.get(name) { diff --git a/src/tests/jit_test.pl b/src/tests/jit_test.pl index 2e7acab51..a2dfc344a 100644 --- a/src/tests/jit_test.pl +++ b/src/tests/jit_test.pl @@ -2,9 +2,12 @@ a(X) :- b(X). b(X) :- X is 1 + 1 + 1 + 1. +c(X, Y) :- X is Y + 1. test :- jit_compile(b/1), jit_compile(a/1), a(X), - X = 4. + X = 4, + jit_compile(c/2), + c(2, 1). diff --git a/tests/scryer/cli/src_tests/jit.stdout b/tests/scryer/cli/src_tests/jit.stdout index 332f8828f..3ef8d8148 100644 --- a/tests/scryer/cli/src_tests/jit.stdout +++ b/tests/scryer/cli/src_tests/jit.stdout @@ -1 +1,2 @@ jit_compiler: executed JIT predicate +jit_compiler: executed JIT predicate From 0f2376b861ead21cbd788bb2b32fadc6bef2e434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Fri, 31 May 2024 23:23:22 +0200 Subject: [PATCH 11/28] Move JIT code to feature "jit" --- .github/workflows/jit.yml | 18 ++++++++++++++++++ Cargo.toml | 7 ++++--- src/machine/mod.rs | 19 +++++++++++-------- src/machine/system_calls.rs | 8 ++++++++ tests/scryer/cli/src_tests/jit.stdout | 2 -- 5 files changed, 41 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/jit.yml diff --git a/.github/workflows/jit.yml b/.github/workflows/jit.yml new file mode 100644 index 000000000..67ec50710 --- /dev/null +++ b/.github/workflows/jit.yml @@ -0,0 +1,18 @@ +name: Test JIT + +on: + pull_request: + workflow_dispatch: + +jobs: + test-jit: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: Setup Rust + uses: ./.github/actions/setup-rust + with: + rust-version: nightly + targets: x86_64-unknown-linux-gnu + - name: Test with JIT + run: cargo +nightly test --features jit diff --git a/Cargo.toml b/Cargo.toml index 82bac54fe..20f7db3da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ tls = ["dep:native-tls"] http = ["dep:warp", "dep:reqwest"] rust_beta_channel = [] crypto-full = [] +jit = ["dep:cranelift", "dep:cranelift-jit", "dep:cranelift-module"] [build-dependencies] indexmap = "1.0.2" @@ -83,9 +84,9 @@ native-tls = { version = "0.2.4", optional = true } warp = { version = "=0.3.5", features = ["tls"], optional = true } reqwest = { version = "0.11.18", optional = true } tokio = { version = "1.28.2", features = ["full"] } -cranelift = "0.105.3" -cranelift-jit = "0.105.3" -cranelift-module = "0.105.3" +cranelift = { version = "0.105.3", optional = true } +cranelift-jit = { version = "0.105.3", optional = true } +cranelift-module = { version = "0.105.3", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.10", features = ["js"] } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index 919613fbe..2a7e343b6 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -12,6 +12,7 @@ pub mod disjuncts; pub mod dispatch; pub mod gc; pub mod heap; +#[cfg(feature = "jit")] pub mod jit; pub mod lib_machine; pub mod load_state; @@ -40,6 +41,7 @@ use crate::machine::args::*; use crate::machine::compile::*; use crate::machine::copier::*; use crate::machine::heap::*; +#[cfg(feature = "jit")] use crate::machine::jit::*; use crate::machine::loader::*; use crate::machine::machine_errors::*; @@ -61,7 +63,6 @@ use std::cmp::Ordering; use std::env; use std::io::Read; use std::path::PathBuf; -use std::pin::*; use std::sync::atomic::AtomicBool; use self::config::MachineConfig; @@ -83,6 +84,7 @@ pub struct Machine { #[cfg(feature = "ffi")] pub(super) foreign_function_table: ForeignFunctionTable, pub(super) rng: StdRng, + #[cfg(feature = "jit")] pub(super) jit_machine: JitMachine, } @@ -470,8 +472,6 @@ impl Machine { ), }; - let jit_machine = JitMachine::new(); - let mut wam = Machine { machine_st, indices: IndexStore::new(), @@ -483,7 +483,8 @@ impl Machine { #[cfg(feature = "ffi")] foreign_function_table: Default::default(), rng: StdRng::from_entropy(), - jit_machine: jit_machine, + #[cfg(feature = "jit")] + jit_machine: JitMachine::new(), }; let mut lib_path = current_dir(); @@ -1148,14 +1149,16 @@ impl Machine { } - #[inline(always)] fn try_execute(&mut self, name: Atom, arity: usize, idx: IndexPtr) -> CallResult { let compiled_tl_index = idx.p() as usize; - if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity), &mut self.machine_st) { - println!("jit_compiler: executed JIT predicate"); - return Ok(()); + #[cfg(feature = "jit")] + { + if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity), &mut self.machine_st) { + // println!("jit_compiler: executed JIT predicate"); + return Ok(()); + } } match idx.tag() { diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 0645e05fa..7ff5353af 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -19,6 +19,7 @@ use crate::machine; use crate::machine::code_walker::*; use crate::machine::copier::*; use crate::machine::heap::*; +#[cfg(feature = "jit")] use crate::machine::jit::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; @@ -4972,9 +4973,16 @@ impl Machine { Ok(()) } + #[cfg(not(feature = "jit"))] + #[inline(always)] + pub(crate) fn jit_compile(&mut self) -> CallResult { + Ok(()) + } + // Tries to compile an existing predicate into native code // For this to work: the predicate must be loaded, must use the subset of Prolog supported by the JIT // and every call to a predicate must have been compiled previously + #[cfg(feature = "jit")] #[inline(always)] pub(crate) fn jit_compile(&mut self) -> CallResult { let module_name = cell_as_atom!(self.deref_register(1)); diff --git a/tests/scryer/cli/src_tests/jit.stdout b/tests/scryer/cli/src_tests/jit.stdout index 3ef8d8148..e69de29bb 100644 --- a/tests/scryer/cli/src_tests/jit.stdout +++ b/tests/scryer/cli/src_tests/jit.stdout @@ -1,2 +0,0 @@ -jit_compiler: executed JIT predicate -jit_compiler: executed JIT predicate From eed33712cf9eb157cb9ab1bf835e59d4c199bc01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sat, 1 Jun 2024 00:23:07 +0200 Subject: [PATCH 12/28] Add SetConstant --- Cargo.lock | 54 ++++++++++++++++++++++++---------------------- Cargo.toml | 8 +++---- src/machine/jit.rs | 41 ++++++++++++++++++++++++++--------- 3 files changed, 63 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c7eb71793..15dba0fba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -494,9 +494,9 @@ dependencies = [ [[package]] name = "cranelift" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d8964676f9a0165edbdb7650344acb73d4decd590685f0964b6109b3691fff8" +checksum = "9c23938094c74206f455ac334e8166a2f3e78cfa604933c97925f159621e6061" dependencies = [ "cranelift-codegen", "cranelift-frontend", @@ -504,18 +504,18 @@ dependencies = [ [[package]] name = "cranelift-bforest" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16d5521e2abca66bbb1ddeecbb6f6965c79160352ae1579b39f8c86183895c24" +checksum = "29daf137addc15da6bab6eae2c4a11e274b1d270bf2759508e62f6145e863ef6" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef40a4338a47506e832ac3e53f7f1375bc59351f049a8379ff736dd02565bd95" +checksum = "de619867d5de4c644b7fd9904d6e3295269c93d8a71013df796ab338681222d4" dependencies = [ "bumpalo", "cranelift-bforest", @@ -528,45 +528,46 @@ dependencies = [ "hashbrown 0.14.3", "log", "regalloc2", + "rustc-hash", "smallvec", "target-lexicon", ] [[package]] name = "cranelift-codegen-meta" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d24cd5d85985c070f73dfca07521d09086362d1590105ba44b0932bf33513b61" +checksum = "29f5cf277490037d8dae9513d35e0ee8134670ae4a964a5ed5b198d4249d7c10" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0584c4363e3aa0a3c7cb98a778fbd5326a3709f117849a727da081d4051726c" +checksum = "8c3e22ecad1123343a3c09ac6ecc532bb5c184b6fcb7888df0ea953727f79924" [[package]] name = "cranelift-control" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f25ecede098c6553fdba362a8e4c9ecb8d40138363bff47f9712db75be7f0571" +checksum = "53ca3ec6d30bce84ccf59c81fead4d16381a3ef0ef75e8403bc1e7385980da09" dependencies = [ "arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea081a42f25dc4c5b248b87efdd87dcd3842a1050a37524ec5391e6172058cb" +checksum = "7eabb8d36b0ca8906bec93c78ea516741cac2d7e6b266fa7b0ffddcc09004990" [[package]] name = "cranelift-frontend" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9796e712f5af797e247784f7518e6b0a83a8907d73d51526982d86ecb3a58b68" +checksum = "44b42630229e49a8cfcae90bdc43c8c4c08f7a7aa4618b67f79265cd2f996dd2" dependencies = [ "cranelift-codegen", "log", @@ -576,15 +577,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a66ccad5782f15c80e9dd5af0df4acfe6e3eee98e8f7354a2e5c8ec3104bdd" +checksum = "918d1e36361805dfe0b6cdfd5a5ffdb5d03fa796170c5717d2727cbe623b93a0" [[package]] name = "cranelift-jit" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dabad94c317b89a8300e9c48ee4715ffc0c1235ec94cf6bb71aa12511cb1afe8" +checksum = "399bad3a10a12c475f812dd889b5a5d296eaacc22d4c9652cdefb25dfb9ae56c" dependencies = [ "anyhow", "cranelift-codegen", @@ -602,9 +603,9 @@ dependencies = [ [[package]] name = "cranelift-module" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d4f782914c709b021e76475516358e64287d17e950113e051468ccf9ca8ff43" +checksum = "1ed398b2df2a409a189dbe7950d365b50aa61380cec1f53d9891d4ad8374edb7" dependencies = [ "anyhow", "cranelift-codegen", @@ -613,9 +614,9 @@ dependencies = [ [[package]] name = "cranelift-native" -version = "0.105.3" +version = "0.108.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "285e80df1d9b79ded9775b285df68b920a277b84f88a7228d2f5bc31fcdc58eb" +checksum = "75aea85a0d7e1800b14ce9d3f53adf8ad4d1ee8a9e23b0269bdc50285e93b9b3" dependencies = [ "cranelift-codegen", "libc", @@ -3792,10 +3793,11 @@ checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasmtime-jit-icache-coherence" -version = "18.0.3" +version = "21.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "866634605089b4632b32226b54aa3670d72e1849f9fc425c7e50b3749c2e6df3" +checksum = "afe088f9b56bb353adaf837bf7e10f1c2e1676719dd5be4cac8e37f2ba1ee5bc" dependencies = [ + "anyhow", "cfg-if", "libc", "windows-sys 0.52.0", diff --git a/Cargo.toml b/Cargo.toml index 20f7db3da..3748a5488 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.70" crate-type = ["cdylib", "rlib", "staticlib"] [features] -default = ["ffi", "repl", "hostname", "tls", "http", "crypto-full"] +default = ["ffi", "repl", "hostname", "tls", "http", "crypto-full", "jit"] ffi = ["dep:libffi"] repl = ["dep:crossterm", "dep:ctrlc", "dep:rustyline"] hostname = ["dep:hostname"] @@ -84,9 +84,9 @@ native-tls = { version = "0.2.4", optional = true } warp = { version = "=0.3.5", features = ["tls"], optional = true } reqwest = { version = "0.11.18", optional = true } tokio = { version = "1.28.2", features = ["full"] } -cranelift = { version = "0.105.3", optional = true } -cranelift-jit = { version = "0.105.3", optional = true } -cranelift-module = { version = "0.105.3", optional = true } +cranelift = { version = "0.108.1", optional = true } +cranelift-jit = { version = "0.108.1", optional = true } +cranelift-module = { version = "0.108.1", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.10", features = ["js"] } diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 6e29a77d1..61bc341dd 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -26,6 +26,7 @@ pub struct JitMachine { trampoline: extern "C" fn (*mut MachineState, *mut Registers, *const u8), offset_interms: usize, offset_arena: usize, + offset_heap: usize, write_literal_to_var: *const u8, deref: *const u8, store: *const u8, @@ -33,6 +34,7 @@ pub struct JitMachine { get_number: *const u8, add: *const u8, vec_as_ptr: *const u8, + vec_push: *const u8, number_try_from: *const u8, } @@ -49,6 +51,7 @@ impl JitMachine { let mut module = JITModule::new(builder); let pointer_type = module.isa().pointer_type(); + let call_conv = module.isa().default_call_conv(); let mut ctx = module.make_context(); let mut func_ctx = FunctionBuilderContext::new(); @@ -56,7 +59,7 @@ impl JitMachine { sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; ctx.func.signature = sig.clone(); let mut func = module.declare_function("$trampoline", Linkage::Local, &sig).unwrap(); @@ -96,12 +99,16 @@ impl JitMachine { let arena_ptr = std::ptr::addr_of!((*machine_st_ptr).arena) as *const u8; arena_ptr.offset_from(machine_st_ptr_u8) as usize }; - + let offset_heap = unsafe { + let heap_ptr = std::ptr::addr_of!((*machine_st_ptr).heap) as *const u8; + heap_ptr.offset_from(machine_st_ptr_u8) as usize + }; JitMachine { modules: HashMap::new(), trampoline: code_ptr, offset_interms: offset_interms, offset_arena: offset_arena, + offset_heap: offset_heap, write_literal_to_var: MachineState::write_literal_to_var as *const u8, deref: MachineState::deref as *const u8, store: MachineState::store as *const u8, @@ -109,6 +116,7 @@ impl JitMachine { get_number: MachineState::get_number as *const u8, add: add_jit as *const u8, vec_as_ptr: Vec::::as_ptr as *const u8, + vec_push: Vec::::push as *const u8, number_try_from: Number::jit_try_from as *const u8, } } @@ -123,6 +131,7 @@ impl JitMachine { let mut module = JITModule::new(builder); let pointer_type = module.isa().pointer_type(); + let call_conv = module.isa().default_call_conv(); let mut ctx = module.make_context(); let mut func_ctx = FunctionBuilderContext::new(); @@ -171,7 +180,7 @@ impl JitMachine { let reg_num = reg.reg_num(); let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((reg_num as i32)*8)); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(types::I64)); sig.returns.push(AbiParam::new(types::I64)); @@ -182,7 +191,7 @@ impl JitMachine { let c = unsafe { std::mem::transmute::(u64::from(c)) }; let c_value = fn_builder.ins().iconst(types::I64, c); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); @@ -219,6 +228,18 @@ impl JitMachine { _ => unimplemented!() } } + Instruction::SetConstant(c) => { + let machine_st = fn_builder.block_params(block)[0]; + let mut sig = module.make_signature(); + sig.call_conv = call_conv; + sig.params.push(AbiParam::new(pointer_type)); + let sig_ref = fn_builder.import_signature(sig); + let heap = fn_builder.ins().iadd_imm(machine_st, self.offset_heap as i64); + let vec_push = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + let c = unsafe { std::mem::transmute::(u64::from(c)) }; + let c_value = fn_builder.ins().iconst(types::I64, c); + fn_builder.ins().call_indirect(sig_ref, vec_push, &[heap, c_value]); + } // TODO Fill more cases. Can we optimize the add in some cases to use the Cranelift add? Instruction::Add(a1, a2, t) => { let machine_state = fn_builder.block_params(block)[0]; @@ -226,7 +247,7 @@ impl JitMachine { let n1 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a1); let n2 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a2); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(types::I128)); sig.params.push(AbiParam::new(types::I128)); sig.params.push(AbiParam::new(pointer_type)); @@ -239,7 +260,7 @@ impl JitMachine { let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(pointer_type)); sig.returns.push(AbiParam::new(pointer_type)); let sig_ref = fn_builder.import_signature(sig); @@ -255,7 +276,7 @@ impl JitMachine { let n = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, at); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(types::I128)); sig.params.push(AbiParam::new(types::I64)); @@ -319,7 +340,7 @@ impl JitMachine { fn jit_get_number(&self, module: &JITModule, machine_state: Value, reg_ptr: Value, fn_builder: &mut FunctionBuilder, at: ArithmeticTerm) -> Value { let pointer_type = module.isa().pointer_type(); - let system_call_conv = module.isa().default_call_conv(); + let call_conv = module.isa().default_call_conv(); match at { ArithmeticTerm::Number(n) => { @@ -331,7 +352,7 @@ impl JitMachine { ArithmeticTerm::Interm(i) => { let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(pointer_type)); sig.returns.push(AbiParam::new(pointer_type)); let sig_ref = fn_builder.import_signature(sig); @@ -343,7 +364,7 @@ impl JitMachine { ArithmeticTerm::Reg(reg_type) => { let value = self.jit_store_deref_reg(&module, machine_state, reg_ptr, fn_builder, reg_type); let mut sig = module.make_signature(); - sig.call_conv = isa::CallConv::SystemV; + sig.call_conv = call_conv; sig.params.push(AbiParam::new(types::I64)); sig.returns.push(AbiParam::new(types::I128)); let sig_ref = fn_builder.import_signature(sig); From 056877cea9074616ea02b35b7b525ceb1393d358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Thu, 6 Jun 2024 22:10:04 +0200 Subject: [PATCH 13/28] Print code --- src/machine/jit.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 61bc341dd..404df2d2a 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -294,16 +294,15 @@ impl JitMachine { } } module.define_function(func, &mut ctx).unwrap(); + println!("{}", ctx.func.display()); module.clear_context(&mut ctx); module.finalize_definitions().unwrap(); - let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; self.modules.insert(name.to_string(), CompileOutput { module, code_ptr }); - Ok(()) } From 0569ef82a576e681b12cede8159d3a745c08b259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 17 Jun 2024 18:03:39 +0200 Subject: [PATCH 14/28] Add jit2 skeleton --- Cargo.lock | 1 + Cargo.toml | 5 +- src/machine/jit2.rs | 134 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 src/machine/jit2.rs diff --git a/Cargo.lock b/Cargo.lock index 15dba0fba..553cbcb6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2765,6 +2765,7 @@ dependencies = [ "console_error_panic_hook", "cpu-time", "cranelift", + "cranelift-codegen", "cranelift-jit", "cranelift-module", "criterion", diff --git a/Cargo.toml b/Cargo.toml index 3748a5488..a4320bd37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.70" crate-type = ["cdylib", "rlib", "staticlib"] [features] -default = ["ffi", "repl", "hostname", "tls", "http", "crypto-full", "jit"] +default = ["ffi", "repl", "hostname", "tls", "http", "crypto-full"] ffi = ["dep:libffi"] repl = ["dep:crossterm", "dep:ctrlc", "dep:rustyline"] hostname = ["dep:hostname"] @@ -24,7 +24,7 @@ tls = ["dep:native-tls"] http = ["dep:warp", "dep:reqwest"] rust_beta_channel = [] crypto-full = [] -jit = ["dep:cranelift", "dep:cranelift-jit", "dep:cranelift-module"] +jit = ["dep:cranelift", "dep:cranelift-jit", "dep:cranelift-module", "dep:cranelift-codegen"] [build-dependencies] indexmap = "1.0.2" @@ -87,6 +87,7 @@ tokio = { version = "1.28.2", features = ["full"] } cranelift = { version = "0.108.1", optional = true } cranelift-jit = { version = "0.108.1", optional = true } cranelift-module = { version = "0.108.1", optional = true } +cranelift-codegen = { version = "0.108.1", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.10", features = ["js"] } diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs new file mode 100644 index 000000000..83fd918d7 --- /dev/null +++ b/src/machine/jit2.rs @@ -0,0 +1,134 @@ +// This is another implementation of the JIT compiler +// In this implementation we do not share data between the interpreter and the compiled code +// We do copies before and after the execution + +use crate::instructions::*; +use crate::machine::*; +use crate::parser::ast::*; + +use cranelift::prelude::*; +use cranelift_jit::{JITBuilder, JITModule}; +use cranelift_module::{Linkage, Module}; +use cranelift_codegen::Context; +use cranelift::prelude::codegen::ir::immediates::Offset32; + +#[derive(Debug, PartialEq)] +pub enum JitCompileError { + UndefinedPredicate, + InstructionNotImplemented, +} + + +pub struct JitMachine { + trampolines: Vec<*const u8>, + module: JITModule, + ctx: Context, + func_ctx: FunctionBuilderContext, +} + +impl std::fmt::Debug for JitMachine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JitMachine") + } +} + + +impl JitMachine { + pub fn new() -> Self { + let builder = JITBuilder::with_flags(&[ + ("preserve_frame_pointers", "true"), + ("enable_llvm_abi_extensions", "1")], + cranelift_module::default_libcall_names() + ).unwrap(); + let mut module = JITModule::new(builder); + let pointer_type = module.isa().pointer_type(); + let call_conv = module.isa().default_call_conv(); + + let mut ctx = module.make_context(); + let mut func_ctx = FunctionBuilderContext::new(); + + let mut sig = module.make_signature(); + sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); + sig.call_conv = call_conv; + + let mut trampolines = vec![]; + + // Should be MAX_ARITY + for n in 0..4 { + ctx.func.signature = sig.clone(); + let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); + let block = fn_builder.create_block(); + fn_builder.append_block_params_for_function_params(block); + fn_builder.switch_to_block(block); + let func_addr = fn_builder.block_params(block)[0]; + let registers = fn_builder.block_params(block)[1]; + let mut jump_sig = module.make_signature(); + jump_sig.call_conv = isa::CallConv::Tail; + let mut params = vec![]; + for i in 1..n+1 { + jump_sig.params.push(AbiParam::new(types::I64)); + let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), registers, Offset32::new((i as i32)*8)); + params.push(reg_value); + } + let jump_sig_ref = fn_builder.import_signature(jump_sig); + fn_builder.ins().call_indirect(jump_sig_ref, func_addr, ¶ms); + fn_builder.ins().return_(&[]); + fn_builder.seal_block(block); + fn_builder.finalize(); + + let func = module.declare_function(&format!("$trampoline/{}", n), Linkage::Local, &sig).unwrap(); + module.define_function(func, &mut ctx).unwrap(); + println!("{}", ctx.func.display()); + module.finalize_definitions().unwrap(); + module.clear_context(&mut ctx); + let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; + trampolines.push(code_ptr); + } + + + JitMachine { + trampolines, + module, + ctx, + func_ctx, + } + } + + // TODO: Compile taking into account arity + pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError> { + let mut fn_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); + let block = fn_builder.create_block(); + fn_builder.append_block_params_for_function_params(block); + fn_builder.switch_to_block(block); + + for wam_instr in code { + match wam_instr { + Instruction::Proceed => { + fn_builder.ins().return_(&[]); + break; + }, + _ => { + fn_builder.finalize(); + self.module.clear_context(&mut self.ctx); + return Err(JitCompileError::InstructionNotImplemented); + } + } + } + fn_builder.seal_all_blocks(); + fn_builder.finalize(); + + let mut sig = self.module.make_signature(); + sig.call_conv = isa::CallConv::Tail; + + let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); + self.module.define_function(func, &mut self.ctx).unwrap(); + println!("{}", self.ctx.func.display()); + self.module.clear_context(&mut self.ctx); + Ok(()) + } + + pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { + Ok(()) + } +} From e8cbd2ed297d3f680e4841c37fe2e6ad10ac8f35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 17 Jun 2024 18:22:50 +0200 Subject: [PATCH 15/28] Fix errors for CI --- src/machine/jit.rs | 2 +- src/machine/machine_state.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/machine/jit.rs b/src/machine/jit.rs index 404df2d2a..4a0a87fbe 100644 --- a/src/machine/jit.rs +++ b/src/machine/jit.rs @@ -294,7 +294,7 @@ impl JitMachine { } } module.define_function(func, &mut ctx).unwrap(); - println!("{}", ctx.func.display()); + // println!("{}", ctx.func.display()); module.clear_context(&mut ctx); module.finalize_definitions().unwrap(); diff --git a/src/machine/machine_state.rs b/src/machine/machine_state.rs index b2ae875d4..3763b2c09 100644 --- a/src/machine/machine_state.rs +++ b/src/machine/machine_state.rs @@ -21,7 +21,6 @@ use indexmap::IndexMap; use std::convert::TryFrom; use std::fmt; -use std::pin::*; use std::ops::{Index, IndexMut}; use std::sync::Arc; From 3dc92f2d83f86df0e433bb500262c1392154295c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 2 Jul 2024 20:59:15 +0200 Subject: [PATCH 16/28] Add more instructions to JIT2 --- src/machine/jit2.rs | 243 ++++++++++++++++++++++++++++++++++-- src/machine/mod.rs | 4 +- src/machine/system_calls.rs | 4 +- 3 files changed, 239 insertions(+), 12 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 83fd918d7..018d6c00c 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -11,6 +11,9 @@ use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module}; use cranelift_codegen::Context; use cranelift::prelude::codegen::ir::immediates::Offset32; +use cranelift::prelude::codegen::ir::entities::Value; + +use std::ops::Index; #[derive(Debug, PartialEq)] pub enum JitCompileError { @@ -24,6 +27,12 @@ pub struct JitMachine { module: JITModule, ctx: Context, func_ctx: FunctionBuilderContext, + heap_as_ptr: *const u8, + heap_as_ptr_sig: Signature, + heap_push: *const u8, + heap_push_sig: Signature, + heap_len: *const u8, + heap_len_sig: Signature, } impl std::fmt::Debug for JitMachine { @@ -85,27 +94,247 @@ impl JitMachine { let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; trampolines.push(code_ptr); } - - + let heap_as_ptr = Vec::::as_ptr as *const u8; + let mut heap_as_ptr_sig = module.make_signature(); + heap_as_ptr_sig.params.push(AbiParam::new(pointer_type)); + heap_as_ptr_sig.returns.push(AbiParam::new(pointer_type)); + let heap_push = Vec::::push as *const u8; + let mut heap_push_sig = module.make_signature(); + heap_push_sig.params.push(AbiParam::new(pointer_type)); + heap_push_sig.params.push(AbiParam::new(types::I64)); + let heap_len = Vec::::len as *const u8; + let mut heap_len_sig = module.make_signature(); + heap_len_sig.params.push(AbiParam::new(pointer_type)); + heap_len_sig.returns.push(AbiParam::new(types::I64)); JitMachine { trampolines, module, ctx, func_ctx, + heap_as_ptr, + heap_as_ptr_sig, + heap_push, + heap_push_sig, + heap_len, + heap_len_sig, } } - // TODO: Compile taking into account arity - pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError> { + pub fn compile(&mut self, name: &str, arity: usize, code: Code) -> Result<(), JitCompileError> { + let mut sig = self.module.make_signature(); + sig.params.push(AbiParam::new(types::I64)); + for _ in 1..=arity { + sig.params.push(AbiParam::new(types::I64)); + sig.returns.push(AbiParam::new(types::I64)); + } + sig.call_conv = isa::CallConv::Tail; + self.ctx.func.signature = sig.clone(); + let mut fn_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); let block = fn_builder.create_block(); fn_builder.append_block_params_for_function_params(block); fn_builder.switch_to_block(block); + fn_builder.seal_block(block); + + let heap = fn_builder.block_params(block)[0]; + let mode = Variable::new(0); + fn_builder.declare_var(mode, types::I8); + let s = Variable::new(1); + fn_builder.declare_var(s, types::I64); + let fail = Variable::new(2); + fn_builder.declare_var(fail, types::I8); + + let mut registers = vec![]; + for i in 1..=arity { + let reg = fn_builder.block_params(block)[i]; + registers.push(reg); + } + + macro_rules! heap_len { + () => { + {let sig_ref = fn_builder.import_signature(self.heap_len_sig.clone()); + let heap_len_fn = fn_builder.ins().iconst(types::I64, self.heap_len as i64); + let call_heap_len = fn_builder.ins().call_indirect(sig_ref, heap_len_fn, &[heap]); + let heap_len = fn_builder.inst_results(call_heap_len)[0]; + heap_len} + } + } + + macro_rules! heap_as_ptr { + () => { + { + let sig_ref = fn_builder.import_signature(self.heap_as_ptr_sig.clone()); + let heap_as_ptr_fn = fn_builder.ins().iconst(types::I64, self.heap_as_ptr as i64); + let call_heap_as_ptr = fn_builder.ins().call_indirect(sig_ref, heap_as_ptr_fn, &[heap]); + let heap_ptr = fn_builder.inst_results(call_heap_as_ptr)[0]; + heap_ptr + } + } + } + + macro_rules! store { + ($x:expr) => { + { + let merge_block = fn_builder.create_block(); + fn_builder.append_block_param(merge_block, types::I64); + let is_var_block = fn_builder.create_block(); + fn_builder.append_block_param(is_var_block, types::I64); + let is_not_var_block = fn_builder.create_block(); + fn_builder.append_block_param(is_not_var_block, types::I64); + let tag = fn_builder.ins().band_imm($x, 64); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + fn_builder.ins().brif(is_var, is_var_block, &[$x], is_not_var_block, &[$x]); + // is_var + fn_builder.switch_to_block(is_var_block); + fn_builder.seal_block(is_var_block); + let param = fn_builder.block_params(is_var_block)[0]; + let idx = fn_builder.ins().ushr_imm(param, 8); + let heap_ptr = heap_as_ptr!(); + let idx_ptr = fn_builder.ins().imul_imm(idx, 8); + let idx_ptr = fn_builder.ins().iadd(heap_ptr, idx_ptr); + let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx_ptr, Offset32::new(0)); + fn_builder.ins().jump(merge_block, &[heap_value]); + // is_not_var + fn_builder.switch_to_block(is_not_var_block); + fn_builder.seal_block(is_not_var_block); + let param = fn_builder.block_params(is_not_var_block)[0]; + fn_builder.ins().jump(merge_block, &[param]); + // merge + fn_builder.switch_to_block(merge_block); + fn_builder.seal_block(merge_block); + fn_builder.block_params(merge_block)[0] + } + } + } + + macro_rules! deref { + ($x:expr) => { + { + let exit_block = fn_builder.create_block(); + fn_builder.append_block_param(exit_block, types::I64); + let loop_block = fn_builder.create_block(); + fn_builder.append_block_param(loop_block, types::I64); + fn_builder.ins().jump(loop_block, &[$x]); + fn_builder.switch_to_block(loop_block); + let addr = fn_builder.block_params(loop_block)[0]; + let value = store!(addr); + // check if is var + let tag = fn_builder.ins().band_imm(value, 64); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + let not_equal = fn_builder.ins().icmp(IntCC::NotEqual, value, addr); + let check = fn_builder.ins().band(is_var, not_equal); + fn_builder.ins().brif(check, loop_block, &[value], exit_block, &[value]); + // exit + fn_builder.seal_block(loop_block); + fn_builder.seal_block(exit_block); + fn_builder.switch_to_block(exit_block); + fn_builder.block_params(exit_block)[0] + + } + } + } for wam_instr in code { match wam_instr { + // TODO Missing RegType Perm + Instruction::PutStructure(name, arity, reg) => { + let atom_cell = atom_as_cell!(name, arity); + let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); + let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let str_cell = fn_builder.ins().bor(heap_len_shift, str_cell); + match reg { + RegType::Temp(x) => { + registers[x] = str_cell; + } + _ => unimplemented!() + } + } + // TODO Missing RegType Perm + Instruction::SetVariable(reg) => { + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + match reg { + RegType::Temp(x) => { + registers[x] = heap_loc_cell; + } + _ => unimplemented!() + } + } + // TODO: Missing RegType Perm + Instruction::SetValue(reg) => { + let value = match reg { + RegType::Temp(x) => { + registers[x] + }, + _ => unimplemented!() + }; + let value = store!(value); + + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + } + // TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate + // TODO: Missing support for PStr and CStr + Instruction::UnifyVariable(reg) => { + let read_block = fn_builder.create_block(); + let write_block = fn_builder.create_block(); + let exit_block = fn_builder.create_block(); + let mode_value = fn_builder.use_var(mode); + fn_builder.ins().brif(mode_value, write_block, &[], read_block, &[]); + fn_builder.seal_block(read_block); + fn_builder.seal_block(write_block); + // read + fn_builder.switch_to_block(read_block); + let heap_ptr = heap_as_ptr!(); + let s_value = fn_builder.use_var(s); + let idx = fn_builder.ins().iadd(heap_ptr, s_value); + let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + let value = deref!(value); + match reg { + RegType::Temp(x) => { + registers[x] = value; + }, + _ => unimplemented!() + } + let sum_s = fn_builder.ins().iadd_imm(s_value, 8); + fn_builder.def_var(s, sum_s); + fn_builder.ins().jump(exit_block, &[]); + // write (equal to SetVariable) + fn_builder.switch_to_block(write_block); + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let heap_len = heap_len!(); + let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); + let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + match reg { + RegType::Temp(x) => { + registers[x] = heap_loc_cell; + } + _ => unimplemented!() + } + fn_builder.ins().jump(exit_block, &[]); + // exit + fn_builder.switch_to_block(exit_block); + fn_builder.seal_block(exit_block); + + } Instruction::Proceed => { - fn_builder.ins().return_(&[]); + fn_builder.ins().return_(®isters); break; }, _ => { @@ -118,8 +347,6 @@ impl JitMachine { fn_builder.seal_all_blocks(); fn_builder.finalize(); - let mut sig = self.module.make_signature(); - sig.call_conv = isa::CallConv::Tail; let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); self.module.define_function(func, &mut self.ctx).unwrap(); @@ -129,6 +356,6 @@ impl JitMachine { } pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { - Ok(()) + Err(()) } } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index 2a7e343b6..78e83e1d9 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -13,7 +13,7 @@ pub mod dispatch; pub mod gc; pub mod heap; #[cfg(feature = "jit")] -pub mod jit; +pub mod jit2; pub mod lib_machine; pub mod load_state; pub mod machine_errors; @@ -42,7 +42,7 @@ use crate::machine::compile::*; use crate::machine::copier::*; use crate::machine::heap::*; #[cfg(feature = "jit")] -use crate::machine::jit::*; +use crate::machine::jit2::*; use crate::machine::loader::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 7ff5353af..518aaa73a 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -20,7 +20,7 @@ use crate::machine::code_walker::*; use crate::machine::copier::*; use crate::machine::heap::*; #[cfg(feature = "jit")] -use crate::machine::jit::*; +use crate::machine::jit2::*; use crate::machine::machine_errors::*; use crate::machine::machine_indices::*; use crate::machine::machine_state::*; @@ -5038,7 +5038,7 @@ impl Machine { let mut code = vec![]; walk_code(&self.code, first_idx, |instr| code.push(instr.clone())); - match self.jit_machine.compile(&format!("{}/{}", name.as_str(), arity), code) { + match self.jit_machine.compile(&name.as_str(), arity, code) { Err(JitCompileError::UndefinedPredicate) => { eprintln!("jit_compiler: undefined_predicate"); self.machine_st.fail = true; From 094d8eb61a01b93cfc3c2c19c85ab2ae06170279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sat, 6 Jul 2024 13:25:03 +0200 Subject: [PATCH 17/28] Implement basic GetConstant to test integration --- src/machine/jit2.rs | 127 +++++++++++++++++++++++++++++++++++++++----- src/machine/mod.rs | 2 +- 2 files changed, 115 insertions(+), 14 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 018d6c00c..56a44ddb1 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -11,9 +11,8 @@ use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module}; use cranelift_codegen::Context; use cranelift::prelude::codegen::ir::immediates::Offset32; -use cranelift::prelude::codegen::ir::entities::Value; -use std::ops::Index; +use std::collections::HashMap; #[derive(Debug, PartialEq)] pub enum JitCompileError { @@ -33,6 +32,7 @@ pub struct JitMachine { heap_push_sig: Signature, heap_len: *const u8, heap_len_sig: Signature, + predicates: HashMap<(String, usize), *const u8>, } impl std::fmt::Debug for JitMachine { @@ -44,11 +44,12 @@ impl std::fmt::Debug for JitMachine { impl JitMachine { pub fn new() -> Self { - let builder = JITBuilder::with_flags(&[ + let mut builder = JITBuilder::with_flags(&[ ("preserve_frame_pointers", "true"), ("enable_llvm_abi_extensions", "1")], cranelift_module::default_libcall_names() ).unwrap(); + builder.symbol("print_func", print_syscall as *const u8); let mut module = JITModule::new(builder); let pointer_type = module.isa().pointer_type(); let call_conv = module.isa().default_call_conv(); @@ -59,6 +60,7 @@ impl JitMachine { let mut sig = module.make_signature(); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); sig.call_conv = call_conv; let mut trampolines = vec![]; @@ -72,16 +74,24 @@ impl JitMachine { fn_builder.switch_to_block(block); let func_addr = fn_builder.block_params(block)[0]; let registers = fn_builder.block_params(block)[1]; + let heap = fn_builder.block_params(block)[2]; let mut jump_sig = module.make_signature(); jump_sig.call_conv = isa::CallConv::Tail; + jump_sig.params.push(AbiParam::new(types::I64)); let mut params = vec![]; + params.push(heap); for i in 1..n+1 { jump_sig.params.push(AbiParam::new(types::I64)); - let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), registers, Offset32::new((i as i32)*8)); + jump_sig.returns.push(AbiParam::new(types::I64)); + let reg_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), registers, Offset32::new((i as i32)*8)); params.push(reg_value); } let jump_sig_ref = fn_builder.import_signature(jump_sig); - fn_builder.ins().call_indirect(jump_sig_ref, func_addr, ¶ms); + let jump_call = fn_builder.ins().call_indirect(jump_sig_ref, func_addr, ¶ms); + for i in 0..n { + let reg_value = fn_builder.inst_results(jump_call)[i]; + fn_builder.ins().store(MemFlags::trusted(), reg_value, registers, Offset32::new(((i as i32) + 1) * 8)); + } fn_builder.ins().return_(&[]); fn_builder.seal_block(block); fn_builder.finalize(); @@ -106,6 +116,8 @@ impl JitMachine { let mut heap_len_sig = module.make_signature(); heap_len_sig.params.push(AbiParam::new(pointer_type)); heap_len_sig.returns.push(AbiParam::new(types::I64)); + + let predicates = HashMap::new(); JitMachine { trampolines, module, @@ -117,10 +129,17 @@ impl JitMachine { heap_push_sig, heap_len, heap_len_sig, + predicates } } pub fn compile(&mut self, name: &str, arity: usize, code: Code) -> Result<(), JitCompileError> { + let mut print_func_sig = self.module.make_signature(); + print_func_sig.params.push(AbiParam::new(types::I64)); + let print_func = self.module + .declare_function("print_func", Linkage::Import, &print_func_sig) + .unwrap(); + let mut sig = self.module.make_signature(); sig.params.push(AbiParam::new(types::I64)); for _ in 1..=arity { @@ -129,6 +148,7 @@ impl JitMachine { } sig.call_conv = isa::CallConv::Tail; self.ctx.func.signature = sig.clone(); + self.ctx.set_disasm(true); let mut fn_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx); let block = fn_builder.create_block(); @@ -150,6 +170,15 @@ impl JitMachine { registers.push(reg); } + macro_rules! print_rt { + ($x:expr) => { + { + let print_func_fn = self.module.declare_func_in_func(print_func, &mut fn_builder.func); + fn_builder.ins().call(print_func_fn, &[$x]); + } + } + } + macro_rules! heap_len { () => { {let sig_ref = fn_builder.import_signature(self.heap_len_sig.clone()); @@ -249,7 +278,7 @@ impl JitMachine { let str_cell = fn_builder.ins().bor(heap_len_shift, str_cell); match reg { RegType::Temp(x) => { - registers[x] = str_cell; + registers[x-1] = str_cell; } _ => unimplemented!() } @@ -266,7 +295,7 @@ impl JitMachine { fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); match reg { RegType::Temp(x) => { - registers[x] = heap_loc_cell; + registers[x-1] = heap_loc_cell; } _ => unimplemented!() } @@ -275,7 +304,7 @@ impl JitMachine { Instruction::SetValue(reg) => { let value = match reg { RegType::Temp(x) => { - registers[x] + registers[x-1] }, _ => unimplemented!() }; @@ -284,6 +313,41 @@ impl JitMachine { let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + } + // TODO: Missing RegType Perm + Instruction::GetStructure(_lvl, name, arity, reg) => { + /*let xi = match reg { + RegType::Temp(x) => { + registers[x] + } + _ => unimplemented!() + }; + let deref = deref!(xi); + let store = store!(deref); + + let var_block = fn_builder.create_block(); + let str_block = fn_builder.create_block(); + let nostr_block = fn_builder.create_block(); + let other_block = fn_builder.create_block(); + let exit_block = fn_builder.create_block(); + + let tag = fn_builder.ins().band_imm(store, 64); + let is_str = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64); + fn_builder.ins().brif(is_str, str_block, &[], nostr_block, &[]); + fn_builder.seal_block(str_block); + fn_builder.seal_block(nostr_block); + // str block + fn_builder.switch_to_block(str_block); + let a = fn_builder.ins().ushr_imm(store, 8); + let heap_ptr = heap_as_ptr!(); + let a = fn_builder.ins().imul_imm(a, 8); + let idx = fn_builder.ins().iadd(heap_ptr, a); + let atom = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0));*/ + // TODO: Check atom values + // TODO: Continue + + + } // TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate // TODO: Missing support for PStr and CStr @@ -304,7 +368,7 @@ impl JitMachine { let value = deref!(value); match reg { RegType::Temp(x) => { - registers[x] = value; + registers[x-1] = value; }, _ => unimplemented!() } @@ -323,7 +387,7 @@ impl JitMachine { fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); match reg { RegType::Temp(x) => { - registers[x] = heap_loc_cell; + registers[x-1] = heap_loc_cell; } _ => unimplemented!() } @@ -333,6 +397,27 @@ impl JitMachine { fn_builder.seal_block(exit_block); } + // TODO: Manage RegType Perm + // TODO: manage NonVar cases + // TODO: Manage failure + Instruction::GetConstant(_, c, reg) => { + let value = match reg { + RegType::Temp(x) => { + registers[x-1] + } + _ => unimplemented!() + }; + //let value = deref!(value); + //let value = store!(value); + // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) + let c = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(c.into_bytes())); + // Let's suppose STORE[addr] is REF + let heap_ptr = heap_as_ptr!(); + let idx = fn_builder.ins().ishl_imm(value, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + fn_builder.ins().store(MemFlags::new(), c, idx, Offset32::new(0)); + } Instruction::Proceed => { fn_builder.ins().return_(®isters); break; @@ -347,15 +432,31 @@ impl JitMachine { fn_builder.seal_all_blocks(); fn_builder.finalize(); - let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); self.module.define_function(func, &mut self.ctx).unwrap(); println!("{}", self.ctx.func.display()); + self.module.finalize_definitions().unwrap(); + let code_ptr = self.module.get_finalized_function(func); + self.predicates.insert((name.to_string(), arity), code_ptr); + println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); self.module.clear_context(&mut self.ctx); Ok(()) } - pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { - Err(()) + pub fn exec(&self, name: &str, arity: usize, machine_st: &mut MachineState) -> Result<(), ()> { + let Some(predicate) = self.predicates.get(&(name.to_string(), arity)) else { + return Err(()); + }; + let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec) = unsafe { std::mem::transmute(self.trampolines[arity])}; + let registers = machine_st.registers.as_ptr() as *mut Registers; + let heap = &machine_st.heap as *const Vec; + trampoline(*predicate, registers, heap); + machine_st.p = machine_st.cp; + Ok(()) } } + + +fn print_syscall(value: i64) { + println!("{}", value); +} diff --git a/src/machine/mod.rs b/src/machine/mod.rs index 78e83e1d9..f1597688c 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -1155,7 +1155,7 @@ impl Machine { #[cfg(feature = "jit")] { - if let Ok(_) = self.jit_machine.exec(&format!("{}/{}", name.as_str(), arity), &mut self.machine_st) { + if let Ok(_) = self.jit_machine.exec(&name.as_str(), arity, &mut self.machine_st) { // println!("jit_compiler: executed JIT predicate"); return Ok(()); } From eb2fe461a0717a8c9b35dbd170b988ba8b46c6c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 9 Jul 2024 19:57:58 +0200 Subject: [PATCH 18/28] Cleanup JIT1 --- src/arithmetic.rs | 6 ------ src/machine/arithmetic_ops.rs | 5 ----- src/machine/dispatch.rs | 13 ------------- src/machine/machine_state_impl.rs | 5 ++--- 4 files changed, 2 insertions(+), 27 deletions(-) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 19910273e..799dfbd65 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -674,12 +674,6 @@ impl Ord for Number { } } -impl Number { - pub extern "C" fn jit_try_from(value: HeapCellValue) -> Number { - Number::try_from(value).unwrap() - } -} - impl TryFrom for Number { type Error = (); diff --git a/src/machine/arithmetic_ops.rs b/src/machine/arithmetic_ops.rs index 959bb67e3..3de47a604 100644 --- a/src/machine/arithmetic_ops.rs +++ b/src/machine/arithmetic_ops.rs @@ -128,11 +128,6 @@ fn isize_gcd(n1: isize, n2: isize) -> Option { Some(n1 << shift as isize) } -// TODO: Error handling -pub extern "C" fn add_jit(lhs: Number, rhs: Number, arena: &mut Arena) -> Number { - add(lhs, rhs, &mut Arena::new()).unwrap() -} - pub(crate) fn add(lhs: Number, rhs: Number, arena: &mut Arena) -> Result { match (lhs, rhs) { (Number::Fixnum(n1), Number::Fixnum(n2)) => Ok( diff --git a/src/machine/dispatch.rs b/src/machine/dispatch.rs index 183b442af..06702b1f1 100644 --- a/src/machine/dispatch.rs +++ b/src/machine/dispatch.rs @@ -187,19 +187,6 @@ impl MachineState { Ok(()) } - pub extern "C" fn unify_num_jit(&mut self, n: Number, n1: HeapCellValue) { - match n { - Number::Fixnum(n) => self.unify_fixnum(n, n1), - Number::Float(n) => { - let n = float_alloc!(n.into_inner(), self.arena); - self.unify_f64(n, n1) - } - Number::Integer(n) => self.unify_big_int(n, n1), - Number::Rational(n) => self.unify_rational(n, n1), - - } - } - #[inline(always)] pub(crate) fn select_switch_on_term_index( &self, diff --git a/src/machine/machine_state_impl.rs b/src/machine/machine_state_impl.rs index c7f478dd8..459db1d76 100644 --- a/src/machine/machine_state_impl.rs +++ b/src/machine/machine_state_impl.rs @@ -19,7 +19,6 @@ use indexmap::IndexSet; use std::cmp::Ordering; use std::convert::TryFrom; -use std::pin::*; impl MachineState { pub(crate) fn new() -> Self { @@ -80,7 +79,7 @@ impl MachineState { } #[inline] - pub extern "C" fn deref(&self, mut addr: HeapCellValue) -> HeapCellValue { + pub fn deref(&self, mut addr: HeapCellValue) -> HeapCellValue { loop { let value = self.store(addr); @@ -1013,7 +1012,7 @@ impl MachineState { } } - pub(super) extern "C" fn write_literal_to_var(&mut self, deref_v: HeapCellValue, lit: HeapCellValue) { + pub(super) fn write_literal_to_var(&mut self, deref_v: HeapCellValue, lit: HeapCellValue) { let store_v = self.store(deref_v); read_heap_cell!(lit, From b9e408ba14c23d066272e8d81566ecfeeb35dc71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 9 Jul 2024 20:55:22 +0200 Subject: [PATCH 19/28] Fix store and deref JIT2 --- src/machine/jit.rs | 455 -------------------------------------------- src/machine/jit2.rs | 25 +-- 2 files changed, 13 insertions(+), 467 deletions(-) delete mode 100644 src/machine/jit.rs diff --git a/src/machine/jit.rs b/src/machine/jit.rs deleted file mode 100644 index 4a0a87fbe..000000000 --- a/src/machine/jit.rs +++ /dev/null @@ -1,455 +0,0 @@ -use std::collections::HashMap; - -use crate::instructions::*; -use crate::machine::*; -use crate::machine::arithmetic_ops::add_jit; - -use cranelift::prelude::*; -use cranelift::prelude::codegen::ir::immediates::Offset32; -use cranelift::prelude::codegen::ir::entities::Value; -use cranelift_jit::{JITBuilder, JITModule}; -use cranelift_module::{Linkage, Module}; - -struct CompileOutput { - module: JITModule, - code_ptr: *const u8, -} - -#[derive(Debug, PartialEq)] -pub enum JitCompileError { - UndefinedPredicate, - InstructionNotImplemented, -} - -pub struct JitMachine { - modules: HashMap, - trampoline: extern "C" fn (*mut MachineState, *mut Registers, *const u8), - offset_interms: usize, - offset_arena: usize, - offset_heap: usize, - write_literal_to_var: *const u8, - deref: *const u8, - store: *const u8, - unify_num: *const u8, - get_number: *const u8, - add: *const u8, - vec_as_ptr: *const u8, - vec_push: *const u8, - number_try_from: *const u8, -} - -impl std::fmt::Debug for JitMachine { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JitMachine") - } -} - -impl JitMachine { - pub fn new() -> Self { - // Build trampoline: from SysV to Tail - let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true"), ("enable_llvm_abi_extensions", "1")], cranelift_module::default_libcall_names()).unwrap(); - - let mut module = JITModule::new(builder); - let pointer_type = module.isa().pointer_type(); - let call_conv = module.isa().default_call_conv(); - let mut ctx = module.make_context(); - let mut func_ctx = FunctionBuilderContext::new(); - - let mut sig = module.make_signature(); - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(pointer_type)); - sig.call_conv = call_conv; - ctx.func.signature = sig.clone(); - - let mut func = module.declare_function("$trampoline", Linkage::Local, &sig).unwrap(); - let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); - let block = fn_builder.create_block(); - fn_builder.append_block_params_for_function_params(block); - fn_builder.switch_to_block(block); - let machine_state = fn_builder.block_params(block)[0]; - let machine_registers = fn_builder.block_params(block)[1]; - let func_addr = fn_builder.block_params(block)[2]; - - let mut sig = module.make_signature(); - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(pointer_type)); - sig.call_conv = isa::CallConv::Tail; - let sig_ref = fn_builder.import_signature(sig); - fn_builder.ins().call_indirect(sig_ref, func_addr, &[machine_state, machine_registers]); - fn_builder.ins().return_(&[]); - fn_builder.seal_all_blocks(); - fn_builder.finalize(); - - module.define_function(func, &mut ctx).unwrap(); - module.clear_context(&mut ctx); - - module.finalize_definitions().unwrap(); - - let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; - - let machine_st = std::mem::MaybeUninit::uninit(); - let machine_st_ptr: *const MachineState = machine_st.as_ptr(); - let machine_st_ptr_u8 = machine_st_ptr as *const u8; - let offset_interms = unsafe { - let interms_ptr = std::ptr::addr_of!((*machine_st_ptr).interms) as *const u8; - interms_ptr.offset_from(machine_st_ptr_u8) as usize - }; - let offset_arena = unsafe { - let arena_ptr = std::ptr::addr_of!((*machine_st_ptr).arena) as *const u8; - arena_ptr.offset_from(machine_st_ptr_u8) as usize - }; - let offset_heap = unsafe { - let heap_ptr = std::ptr::addr_of!((*machine_st_ptr).heap) as *const u8; - heap_ptr.offset_from(machine_st_ptr_u8) as usize - }; - JitMachine { - modules: HashMap::new(), - trampoline: code_ptr, - offset_interms: offset_interms, - offset_arena: offset_arena, - offset_heap: offset_heap, - write_literal_to_var: MachineState::write_literal_to_var as *const u8, - deref: MachineState::deref as *const u8, - store: MachineState::store as *const u8, - unify_num: MachineState::unify_num_jit as *const u8, - get_number: MachineState::get_number as *const u8, - add: add_jit as *const u8, - vec_as_ptr: Vec::::as_ptr as *const u8, - vec_push: Vec::::push as *const u8, - number_try_from: Number::jit_try_from as *const u8, - } - } - - // For now, one module = one predicate - // Access to MachineState via global pointer - // MachineState Registers + ShadowRegisters?? - // Use TAIL call convention - pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError>{ - let mut builder = JITBuilder::with_flags(&[("preserve_frame_pointers", "true"), ("enable_llvm_abi_extensions", "1")], cranelift_module::default_libcall_names()).unwrap(); - builder.symbols(self.modules.iter().map(|(k, v)| (k,v.code_ptr))); - - let mut module = JITModule::new(builder); - let pointer_type = module.isa().pointer_type(); - let call_conv = module.isa().default_call_conv(); - let mut ctx = module.make_context(); - let mut func_ctx = FunctionBuilderContext::new(); - - let mut sig = module.make_signature(); - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(pointer_type)); - sig.call_conv = isa::CallConv::Tail; - ctx.func.signature = sig.clone(); - - let mut func = module.declare_function(name, Linkage::Local, &sig).unwrap(); - - let mut fn_builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx); - let block = fn_builder.create_block(); - fn_builder.append_block_params_for_function_params(block); - fn_builder.switch_to_block(block); - // TODO: Manage failure - for wam_instr in code { - match wam_instr { - Instruction::Proceed => { - fn_builder.ins().return_(&[]); - fn_builder.seal_all_blocks(); - fn_builder.finalize(); - break; - } - Instruction::ExecuteNamed(arity, pred_name, ..) => { - let machine_state_value = fn_builder.block_params(block)[0]; - let reg_ptr = fn_builder.block_params(block)[1]; - let mut callee_func_sig = module.make_signature(); - callee_func_sig.call_conv = isa::CallConv::Tail; - callee_func_sig.params.push(AbiParam::new(pointer_type)); - callee_func_sig.params.push(AbiParam::new(pointer_type)); - if let Ok(callee_func) = module.declare_function(&format!("{}/{}", pred_name.as_str(), arity), Linkage::Import, &callee_func_sig) { - let func_ref = module.declare_func_in_func(callee_func, fn_builder.func); - fn_builder.ins().return_call(func_ref, &[machine_state_value, reg_ptr]); - fn_builder.seal_all_blocks(); - fn_builder.finalize(); - break; - } else { - return Err(JitCompileError::UndefinedPredicate); - } - } - // TODO Manage RegType - Instruction::GetConstant(_, c, reg) => { - let machine_state_value = fn_builder.block_params(block)[0]; - let reg_ptr = fn_builder.block_params(block)[1]; - let reg_num = reg.reg_num(); - let reg_value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((reg_num as i32)*8)); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); - let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state_value, reg_value]); - let reg_value = fn_builder.inst_results(deref_call)[0]; - let c = unsafe { std::mem::transmute::(u64::from(c)) }; - let c_value = fn_builder.ins().iconst(types::I64, c); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(types::I64)); - sig.params.push(AbiParam::new(types::I64)); - sig.params.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let write_literal_to_var = fn_builder.ins().iconst(types::I64, self.write_literal_to_var as i64); - fn_builder.ins().call_indirect(sig_ref, write_literal_to_var, &[machine_state_value, reg_value, c_value]); - } - // TODO: Manage RegType - Instruction::GetVariable(norm, arg) => { - let reg_ptr = fn_builder.block_params(block)[1]; - let value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((arg as i32)*8)); - match norm { - RegType::Temp(temp) => { - fn_builder.ins().store(MemFlags::new(), value, reg_ptr, Offset32::new((temp as i32)*8)); - } - _ => unimplemented!() - } - } - // TODO Manage RegType - Instruction::PutConstant(_, c, reg) => { - let reg_ptr = fn_builder.block_params(block)[1]; - let reg_num = reg.reg_num(); - let c = unsafe { std::mem::transmute::(u64::from(c)) }; - let c_value = fn_builder.ins().iconst(types::I64, c); - fn_builder.ins().store(MemFlags::new(), c_value, reg_ptr, Offset32::new((reg_num as i32)*8)); - } - Instruction::PutValue(norm, arg) => { - let reg_ptr = fn_builder.block_params(block)[1]; - match norm { - RegType::Temp(temp) => { - let value = fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)); - fn_builder.ins().store(MemFlags::new(), value, reg_ptr, Offset32::new((arg as i32)*8)); - } - _ => unimplemented!() - } - } - Instruction::SetConstant(c) => { - let machine_st = fn_builder.block_params(block)[0]; - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let heap = fn_builder.ins().iadd_imm(machine_st, self.offset_heap as i64); - let vec_push = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - let c = unsafe { std::mem::transmute::(u64::from(c)) }; - let c_value = fn_builder.ins().iconst(types::I64, c); - fn_builder.ins().call_indirect(sig_ref, vec_push, &[heap, c_value]); - } - // TODO Fill more cases. Can we optimize the add in some cases to use the Cranelift add? - Instruction::Add(a1, a2, t) => { - let machine_state = fn_builder.block_params(block)[0]; - let reg_ptr = fn_builder.block_params(block)[1]; - let n1 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a1); - let n2 = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, a2); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(types::I128)); - sig.params.push(AbiParam::new(types::I128)); - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(types::I128)); - let sig_ref = fn_builder.import_signature(sig); - let arena = fn_builder.ins().iadd_imm(machine_state, self.offset_arena as i64); - let add_jit = fn_builder.ins().iconst(types::I64, self.add as i64); - let add_jit_call = fn_builder.ins().call_indirect(sig_ref, add_jit, &[n1, n2, arena]); - let n3 = fn_builder.inst_results(add_jit_call)[0]; - - let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); - let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); - let interms = fn_builder.inst_results(vec_ptr_call)[0]; - fn_builder.ins().store(MemFlags::new(), n3, interms, Offset32::new((t as i32 - 1) * 16)); - } - Instruction::ExecuteIs(r, at) => { - let machine_state = fn_builder.block_params(block)[0]; - let reg_ptr = fn_builder.block_params(block)[1]; - let n1 = self.jit_store_deref_reg(&module, machine_state, reg_ptr, &mut fn_builder, r); - let n = self.jit_get_number(&module, machine_state, reg_ptr, &mut fn_builder, at); - - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I128)); - sig.params.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let unify_num = fn_builder.ins().iconst(types::I64, self.unify_num as i64); - fn_builder.ins().call_indirect(sig_ref, unify_num, &[machine_state, n, n1]); - fn_builder.ins().return_(&[]); - fn_builder.seal_all_blocks(); - fn_builder.finalize(); - break; - } - _ => { - return Err(JitCompileError::InstructionNotImplemented); - } - } - } - module.define_function(func, &mut ctx).unwrap(); - // println!("{}", ctx.func.display()); - module.clear_context(&mut ctx); - - module.finalize_definitions().unwrap(); - let code_ptr = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; - self.modules.insert(name.to_string(), CompileOutput { - module, - code_ptr - }); - Ok(()) - } - - fn jit_store_deref_reg(&self, module: &JITModule, machine_state: Value, reg_ptr: Value, fn_builder: &mut FunctionBuilder, reg: RegType) -> Value { - let pointer_type = module.isa().pointer_type(); - let system_call_conv = module.isa().default_call_conv(); - let n1 = match reg { - RegType::Temp(temp) => { - fn_builder.ins().load(types::I64, MemFlags::new(), reg_ptr, Offset32::new((temp as i32)*8)) - } - _ => unimplemented!() // TODO - }; - - let mut sig = module.make_signature(); - sig.call_conv = system_call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let deref = fn_builder.ins().iconst(types::I64, self.deref as i64); - let deref_call = fn_builder.ins().call_indirect(sig_ref, deref, &[machine_state, n1]); - let n1 = fn_builder.inst_results(deref_call)[0]; - - let mut sig = module.make_signature(); - sig.call_conv = system_call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); - let sig_ref = fn_builder.import_signature(sig); - let store = fn_builder.ins().iconst(types::I64, self.store as i64); - let store_call = fn_builder.ins().call_indirect(sig_ref, store, &[machine_state, n1]); - fn_builder.inst_results(store_call)[0] - } - - fn jit_get_number(&self, module: &JITModule, machine_state: Value, reg_ptr: Value, fn_builder: &mut FunctionBuilder, at: ArithmeticTerm) -> Value { - let pointer_type = module.isa().pointer_type(); - let call_conv = module.isa().default_call_conv(); - - match at { - ArithmeticTerm::Number(n) => { - let n128 = unsafe { std::mem::transmute::<_, i128>(n) }; - let lo = fn_builder.ins().iconst(types::I64, n128 as i64); - let hi = fn_builder.ins().iconst(types::I64, (n128 >> 64) as i64); - fn_builder.ins().iconcat(lo, hi) - } - ArithmeticTerm::Interm(i) => { - let interms_vec = fn_builder.ins().iadd_imm(machine_state, self.offset_interms as i64); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(pointer_type)); - sig.returns.push(AbiParam::new(pointer_type)); - let sig_ref = fn_builder.import_signature(sig); - let vec_ptr = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); - let vec_ptr_call = fn_builder.ins().call_indirect(sig_ref, vec_ptr, &[interms_vec]); - let interms = fn_builder.inst_results(vec_ptr_call)[0]; - fn_builder.ins().load(types::I128, MemFlags::new(), interms, Offset32::new((i as i32 - 1) * 16)) - } - ArithmeticTerm::Reg(reg_type) => { - let value = self.jit_store_deref_reg(&module, machine_state, reg_ptr, fn_builder, reg_type); - let mut sig = module.make_signature(); - sig.call_conv = call_conv; - sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I128)); - let sig_ref = fn_builder.import_signature(sig); - let number_try_from = fn_builder.ins().iconst(types::I64, self.number_try_from as i64); - let number_try_from_call = fn_builder.ins().call_indirect(sig_ref, number_try_from, &[value]); - fn_builder.inst_results(number_try_from_call)[0] - } - } - } - - - pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> { - if let Some(output) = self.modules.get(name) { - machine_st.p = machine_st.cp; - machine_st.oip = 0; - machine_st.iip = 0; - // machine_st.num_of_args = arity; - machine_st.num_of_args = 1; - machine_st.b0 = machine_st.b; - - (self.trampoline)(machine_st as *mut MachineState, machine_st.registers.as_ptr() as *mut Registers, output.code_ptr); - Ok(()) - } else { - Err(()) - } - } -} - -// basic. -#[test] -fn jit_test_proceed() { - let mut machine_st = MachineState::new(); - let code = vec![Instruction::Proceed]; - let name = "basic/0"; - - let mut jit = JitMachine::new(); - assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name, &mut machine_st); -} - -// basic. -// simple :- basic. -#[test] -fn jit_test_execute_named() { - let mut machine_st = MachineState::new(); - let mut jit = JitMachine::new(); - let code = vec![Instruction::Proceed]; - let name = "basic/0"; - assert_eq!(jit.compile(name, code), Ok(())); - - let code = vec![Instruction::ExecuteNamed(0, atom!("basic"), CodeIndex::default(&mut Arena::new()))]; - let name = "simple/0"; - assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name, &mut machine_st); -} - -// a(5). -// b :- a(5). -#[test] -fn jit_test_get_constant() { - let mut machine_st = MachineState::new(); - let mut jit = JitMachine::new(); - let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; - let name = "a/1"; - assert_eq!(jit.compile(name, code), Ok(())); - - let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; - let name = "b/0"; - assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name, &mut machine_st); - assert_eq!(machine_st.fail, false); -} - -// a(5). -// b :- a(6). -#[test] -fn jit_test_get_constant_fail() { - let mut machine_st = MachineState::new(); - let mut jit = JitMachine::new(); - let code = vec![Instruction::GetConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(5)), RegType::Temp(1)), Instruction::Proceed]; - let name = "a/1"; - assert_eq!(jit.compile(name, code), Ok(())); - - let code = vec![Instruction::PutConstant(Level::Shallow, fixnum_as_cell!(Fixnum::build_with(6)), RegType::Temp(1)), Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut Arena::new()))]; - let name = "b/0"; - assert_eq!(jit.compile(name, code), Ok(())); - jit.exec(name, &mut machine_st); - assert_eq!(machine_st.fail, true); -} diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 56a44ddb1..8441f25b7 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -98,7 +98,7 @@ impl JitMachine { let func = module.declare_function(&format!("$trampoline/{}", n), Linkage::Local, &sig).unwrap(); module.define_function(func, &mut ctx).unwrap(); - println!("{}", ctx.func.display()); + // println!("{}", ctx.func.display()); module.finalize_definitions().unwrap(); module.clear_context(&mut ctx); let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; @@ -210,18 +210,18 @@ impl JitMachine { fn_builder.append_block_param(is_var_block, types::I64); let is_not_var_block = fn_builder.create_block(); fn_builder.append_block_param(is_not_var_block, types::I64); - let tag = fn_builder.ins().band_imm($x, 64); + let tag = fn_builder.ins().ushr_imm($x, 58); let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); fn_builder.ins().brif(is_var, is_var_block, &[$x], is_not_var_block, &[$x]); // is_var fn_builder.switch_to_block(is_var_block); fn_builder.seal_block(is_var_block); let param = fn_builder.block_params(is_var_block)[0]; - let idx = fn_builder.ins().ushr_imm(param, 8); let heap_ptr = heap_as_ptr!(); - let idx_ptr = fn_builder.ins().imul_imm(idx, 8); - let idx_ptr = fn_builder.ins().iadd(heap_ptr, idx_ptr); - let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx_ptr, Offset32::new(0)); + let idx = fn_builder.ins().ishl_imm(param, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); fn_builder.ins().jump(merge_block, &[heap_value]); // is_not_var fn_builder.switch_to_block(is_not_var_block); @@ -248,7 +248,7 @@ impl JitMachine { let addr = fn_builder.block_params(loop_block)[0]; let value = store!(addr); // check if is var - let tag = fn_builder.ins().band_imm(value, 64); + let tag = fn_builder.ins().ushr_imm($x, 58); let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); let not_equal = fn_builder.ins().icmp(IntCC::NotEqual, value, addr); let check = fn_builder.ins().band(is_var, not_equal); @@ -399,7 +399,8 @@ impl JitMachine { } // TODO: Manage RegType Perm // TODO: manage NonVar cases - // TODO: Manage failure + // TODO: Manage unification case + // TODO: manage STORE[addr] is not REF Instruction::GetConstant(_, c, reg) => { let value = match reg { RegType::Temp(x) => { @@ -407,8 +408,8 @@ impl JitMachine { } _ => unimplemented!() }; - //let value = deref!(value); - //let value = store!(value); + let value = deref!(value); + let value = store!(value); // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) let c = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(c.into_bytes())); // Let's suppose STORE[addr] is REF @@ -434,11 +435,11 @@ impl JitMachine { let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); self.module.define_function(func, &mut self.ctx).unwrap(); - println!("{}", self.ctx.func.display()); + // println!("{}", self.ctx.func.display()); self.module.finalize_definitions().unwrap(); let code_ptr = self.module.get_finalized_function(func); self.predicates.insert((name.to_string(), arity), code_ptr); - println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); + // println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); self.module.clear_context(&mut self.ctx); Ok(()) } From d32a574252814089cee4ce9d5ddbe79e00334210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sun, 21 Jul 2024 19:43:16 +0200 Subject: [PATCH 20/28] BIND operation and comments --- src/machine/jit2.rs | 158 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 131 insertions(+), 27 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 8441f25b7..0bfa2a829 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -163,6 +163,8 @@ impl JitMachine { fn_builder.declare_var(s, types::I64); let fail = Variable::new(2); fn_builder.declare_var(fail, types::I8); + let fail_value_init = fn_builder.ins().iconst(types::I8, 0); + fn_builder.def_var(fail, fail_value_init); let mut registers = vec![]; for i in 1..=arity { @@ -201,6 +203,9 @@ impl JitMachine { } } + /* STORE is an operation that abstracts the access to a data cell. It can return the data cell itself + * or some other data cell if it points to some data cell in the heap + */ macro_rules! store { ($x:expr) => { { @@ -236,6 +241,9 @@ impl JitMachine { } } + /* DEREF is an operation that follows a chain of REF until arriving at a self-referential REF + * (unbounded variable) or something that is not a REF + */ macro_rules! deref { ($x:expr) => { { @@ -263,9 +271,54 @@ impl JitMachine { } } + /* BIND is an operation that takes two data cells, one of them being an unbounded REF / Var + * and makes that REF point the other cell on the heap + */ + macro_rules! bind { + ($x:expr, $y:expr) => { + { + let first_var_block = fn_builder.create_block(); + let else_first_var_block = fn_builder.create_block(); + let exit_block = fn_builder.create_block(); + let heap_ptr = heap_as_ptr!(); + // check if x is var + let tag = fn_builder.ins().ushr_imm($x, 58); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + fn_builder.ins().brif(is_var, first_var_block, &[], else_first_var_block, &[]); + // first var block + fn_builder.seal_block(first_var_block); + fn_builder.seal_block(else_first_var_block); + fn_builder.switch_to_block(first_var_block); + // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) + let idx = fn_builder.ins().ishl_imm($x, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + fn_builder.ins().store(MemFlags::trusted(), $y, idx, Offset32::new(0)); + fn_builder.ins().jump(exit_block, &[]); + // else_first_var_block + // suppose the other cell is a var + fn_builder.switch_to_block(else_first_var_block); + let idx = fn_builder.ins().ishl_imm($y, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + fn_builder.ins().store(MemFlags::trusted(), $x, idx, Offset32::new(0)); + fn_builder.ins().jump(exit_block, &[]); + // exit + fn_builder.seal_block(exit_block); + fn_builder.switch_to_block(exit_block); + } + } + } + + // TODO: Unify + for wam_instr in code { match wam_instr { // TODO Missing RegType Perm + /* put_structure is an instruction that puts a new STR in the heap + * (STR cell, plus functor + arity cell) + * It also saves the STR cell into a register + */ Instruction::PutStructure(name, arity, reg) => { let atom_cell = atom_as_cell!(name, arity); let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); @@ -284,12 +337,14 @@ impl JitMachine { } } // TODO Missing RegType Perm + /* set_variable is an instruction that creates a new self-referential REF + * (unbounded variable) in the heap and saves that into a register + */ Instruction::SetVariable(reg) => { let heap_loc_cell = heap_loc_as_cell!(0); let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); let heap_len = heap_len!(); - let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); - let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let heap_loc_cell = fn_builder.ins().bor(heap_len, heap_loc_cell); let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); @@ -301,6 +356,8 @@ impl JitMachine { } } // TODO: Missing RegType Perm + /* set_value is an instruction that pushes a new data cell from a register to the heap + */ Instruction::SetValue(reg) => { let value = match reg { RegType::Temp(x) => { @@ -315,8 +372,14 @@ impl JitMachine { fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); } // TODO: Missing RegType Perm + /* get_structure is an instruction that either matches an existing STR or starts writing a new + * STR into the heap. If the cell passed via register is an unbounded REF, we start WRITE mode + * and unify_variable, unify_value behave similar to set_variable, set_value. + * If it's an existing STR, we check functor and arity, set S pointer and start READ mode, + * in that mod unify_variable and unify_value will follow the unification procedure + */ Instruction::GetStructure(_lvl, name, arity, reg) => { - /*let xi = match reg { + let xi = match reg { RegType::Temp(x) => { registers[x] } @@ -325,32 +388,78 @@ impl JitMachine { let deref = deref!(xi); let store = store!(deref); - let var_block = fn_builder.create_block(); - let str_block = fn_builder.create_block(); - let nostr_block = fn_builder.create_block(); - let other_block = fn_builder.create_block(); + let is_var_block = fn_builder.create_block(); + let else_is_var_block = fn_builder.create_block(); + let is_str_block = fn_builder.create_block(); + let start_read_mode_block = fn_builder.create_block(); + let fail_block = fn_builder.create_block(); let exit_block = fn_builder.create_block(); - let tag = fn_builder.ins().band_imm(store, 64); + let tag = fn_builder.ins().ushr_imm(store, 58); + let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); + fn_builder.ins().brif(is_var, is_var_block, &[], else_is_var_block, &[]); + fn_builder.seal_block(is_var_block); + fn_builder.seal_block(else_is_var_block); + // is_var_block + fn_builder.switch_to_block(is_var_block); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); + let heap_len = heap_len!(); + let heap_len_plus = fn_builder.ins().iadd_imm(heap_len, 1); + let str_cell = fn_builder.ins().bor(heap_len_plus, str_cell); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, str_cell]); + let atom_cell = atom_as_cell!(name, arity); + let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); + bind!(store, str_cell); + let mode_value = fn_builder.ins().iconst(types::I8, 1); + fn_builder.def_var(mode, mode_value); + fn_builder.ins().jump(exit_block, &[]); + + // else_is_var_block + fn_builder.switch_to_block(else_is_var_block); let is_str = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64); - fn_builder.ins().brif(is_str, str_block, &[], nostr_block, &[]); - fn_builder.seal_block(str_block); - fn_builder.seal_block(nostr_block); - // str block - fn_builder.switch_to_block(str_block); - let a = fn_builder.ins().ushr_imm(store, 8); - let heap_ptr = heap_as_ptr!(); - let a = fn_builder.ins().imul_imm(a, 8); - let idx = fn_builder.ins().iadd(heap_ptr, a); - let atom = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0));*/ - // TODO: Check atom values - // TODO: Continue - + fn_builder.ins().brif(is_str, is_str_block, &[], fail_block, &[]); + fn_builder.seal_block(is_str_block); + // is_str_block + fn_builder.switch_to_block(is_str_block); + let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); + let heap_ptr = heap_as_ptr!(); + let idx = fn_builder.ins().ishl_imm(store, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + let result = fn_builder.ins().icmp(IntCC::Equal, heap_value, atom); + fn_builder.ins().brif(result, start_read_mode_block, &[], fail_block, &[]); + fn_builder.seal_block(start_read_mode_block); + fn_builder.seal_block(fail_block); + + // start_read_mode_block + fn_builder.switch_to_block(start_read_mode_block); + let s_ptr = fn_builder.ins().iadd_imm(heap_ptr, 8); + fn_builder.def_var(s, s_ptr); + let mode_value = fn_builder.ins().iconst(types::I8, 0); + fn_builder.def_var(mode, mode_value); + fn_builder.ins().jump(exit_block, &[]); + + // fail_block + fn_builder.switch_to_block(fail_block); + let fail_value = fn_builder.ins().iconst(types::I8, 1); + fn_builder.def_var(fail, fail_value); + fn_builder.ins().jump(exit_block, &[]); + + // exit_block + fn_builder.seal_block(exit_block); + fn_builder.switch_to_block(exit_block); } // TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate // TODO: Missing support for PStr and CStr + /* unify_variable is an instruction that in WRITE mode is identical to set_variable but + * in READ mode it reads the data cell from the S pointer to a register + */ Instruction::UnifyVariable(reg) => { let read_block = fn_builder.create_block(); let write_block = fn_builder.create_block(); @@ -412,12 +521,7 @@ impl JitMachine { let value = store!(value); // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) let c = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(c.into_bytes())); - // Let's suppose STORE[addr] is REF - let heap_ptr = heap_as_ptr!(); - let idx = fn_builder.ins().ishl_imm(value, 8); - let idx = fn_builder.ins().ushr_imm(idx, 5); - let idx = fn_builder.ins().iadd(heap_ptr, idx); - fn_builder.ins().store(MemFlags::new(), c, idx, Offset32::new(0)); + bind!(value, c); } Instruction::Proceed => { fn_builder.ins().return_(®isters); From 4a5f5bd1d0b37cf63b95f1bd3e5bd4b240e0ead7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 23 Jul 2024 22:53:47 +0200 Subject: [PATCH 21/28] Unify (TBD), fix stuff, more tests --- src/machine/jit2.rs | 328 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 312 insertions(+), 16 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 0bfa2a829..dedf4d01c 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -9,7 +9,7 @@ use crate::parser::ast::*; use cranelift::prelude::*; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module}; -use cranelift_codegen::Context; +use cranelift_codegen::{ir::stackslot::*, ir::entities::*, Context}; use cranelift::prelude::codegen::ir::immediates::Offset32; use std::collections::HashMap; @@ -167,9 +167,15 @@ impl JitMachine { fn_builder.def_var(fail, fail_value_init); let mut registers = vec![]; - for i in 1..=arity { - let reg = fn_builder.block_params(block)[i]; - registers.push(reg); + // TODO: This could be optimized more, we know the maximum register we're using + for i in 1..MAX_ARITY { + if i <= arity { + let reg = fn_builder.block_params(block)[i]; + registers.push(reg); + } else { + let reg = fn_builder.ins().iconst(types::I64, 0); + registers.push(reg); + } } macro_rules! print_rt { @@ -310,8 +316,141 @@ impl JitMachine { } } + macro_rules! is_var { + ($x:expr) => { + { + print_rt!($x); + let tag = fn_builder.ins().ushr_imm($x, 58); + print_rt!(tag); + dbg!(HeapCellValueTag::Var as i64); + fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64) + } + } + } + + macro_rules! is_str { + ($x:expr) => { + { + print_rt!($x); + let tag = fn_builder.ins().ushr_imm($x, 58); + print_rt!(tag); + dbg!(HeapCellValueTag::Str as i64); + fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64) + } + } + } + // TODO: Unify + macro_rules! unify { + ($x:expr, $y:expr) => { + { + let loop_block = fn_builder.create_block(); + fn_builder.append_block_param(loop_block, types::I64); + let continue_loop_block = fn_builder.create_block(); + let check_var_block = fn_builder.create_block(); + let is_var_block = fn_builder.create_block(); + let is_str_block = fn_builder.create_block(); + let push_str_block = fn_builder.create_block(); + let push_str_loop_block = fn_builder.create_block(); + let push_str_loop_check_block = fn_builder.create_block(); + let exit_failure = fn_builder.create_block(); + let exit = fn_builder.create_block(); + let pdl = fn_builder.create_dynamic_stack_slot(DynamicStackSlotData::new(StackSlotKind::ExplicitDynamicSlot, DynamicType::from_u32(64))); + let pdl_size = fn_builder.ins().iconst(types::I64, 2); + fn_builder.ins().dynamic_stack_store($x, pdl); + fn_builder.ins().dynamic_stack_store($y, pdl); + fn_builder.ins().jump(loop_block, &[pdl_size]); + // loop_block + fn_builder.switch_to_block(loop_block); + let pdl_size = fn_builder.block_params(loop_block)[0]; + fn_builder.ins().brif(pdl_size, continue_loop_block, &[], exit, &[]); + fn_builder.seal_block(continue_loop_block); + fn_builder.seal_block(exit); + // continue_loop_block + fn_builder.switch_to_block(continue_loop_block); + let d1 = fn_builder.ins().dynamic_stack_load(types::I64, pdl); + let d1 = deref!(d1); + let d2 = fn_builder.ins().dynamic_stack_load(types::I64, pdl); + let d2 = deref!(d2); + let pdl_size = fn_builder.ins().iadd_imm(pdl_size, -2); + let are_equal = fn_builder.ins().icmp(IntCC::Equal, d1, d2); + fn_builder.ins().brif(are_equal, loop_block, &[pdl_size], check_var_block, &[]); + fn_builder.seal_block(check_var_block); + // check_var_block + fn_builder.switch_to_block(check_var_block); + let d1 = store!(d1); + let d2 = store!(d2); + let tag_d1 = fn_builder.ins().ushr_imm(d1, 58); + let is_var_d1 = fn_builder.ins().icmp_imm(IntCC::Equal, tag_d1, HeapCellValueTag::Var as i64); + let tag_d2 = fn_builder.ins().ushr_imm(d2, 58); + let is_var_d2 = fn_builder.ins().icmp_imm(IntCC::Equal, tag_d2, HeapCellValueTag::Var as i64); + let any_var = fn_builder.ins().bor(is_var_d1, is_var_d2); + fn_builder.ins().brif(any_var, is_var_block, &[], is_str_block, &[]); + fn_builder.seal_block(is_var_block); + fn_builder.seal_block(is_str_block); + // is_var_block + fn_builder.switch_to_block(is_var_block); + bind!(d1, d2); + fn_builder.ins().jump(loop_block, &[pdl_size]); + + // is_str_block + fn_builder.switch_to_block(is_str_block); + let heap_ptr = heap_as_ptr!(); + let idx = fn_builder.ins().ishl_imm(d1, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx1 = fn_builder.ins().iadd(heap_ptr, idx); + let heap_d1 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx1, Offset32::new(0)); + let idx = fn_builder.ins().ishl_imm(d2, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx2 = fn_builder.ins().iadd(heap_ptr, idx); + let heap_d2 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx2, Offset32::new(0)); + let is_str_equal = fn_builder.ins().icmp(IntCC::Equal, heap_d1, heap_d2); + fn_builder.ins().brif(is_str_equal, push_str_block, &[], exit_failure, &[]); + fn_builder.seal_block(push_str_block); + fn_builder.seal_block(exit_failure); + + // push_str_block + fn_builder.switch_to_block(push_str_block); + let arity = fn_builder.ins().ishl_imm(heap_d1, 8); + let arity = fn_builder.ins().ushr_imm(arity, 54); + let n = fn_builder.ins().iconst(types::I64, 8); + fn_builder.ins().jump(push_str_loop_check_block, &[pdl_size, n]); + fn_builder.seal_block(push_str_loop_check_block); + // push_str_loop_check_block + fn_builder.switch_to_block(push_str_loop_check_block); + let pdl_size = fn_builder.block_params(push_str_loop_check_block)[0]; + let n = fn_builder.block_params(push_str_loop_check_block)[1]; + let n_is_arity = fn_builder.ins().icmp(IntCC::Equal, arity, n); + fn_builder.ins().brif(n_is_arity, loop_block, &[], push_str_loop_block, &[pdl_size, n]); + fn_builder.seal_block(loop_block); + + // push_str_loop_block + fn_builder.switch_to_block(push_str_loop_block); + let n = fn_builder.block_params(push_str_loop_block)[0]; + let v1 = fn_builder.ins().iadd(idx1, n); + let v2 = fn_builder.ins().iadd(idx2, n); + let heap_v1 = fn_builder.ins().load(types::I64, MemFlags::trusted(), v1, Offset32::new(0)); + let heap_v2 = fn_builder.ins().load(types::I64, MemFlags::trusted(), v2, Offset32::new(0)); + fn_builder.ins().dynamic_stack_store(heap_v1, pdl); + fn_builder.ins().dynamic_stack_store(heap_v2, pdl); + let pdl_size = fn_builder.ins().iadd_imm(pdl_size, 2); + let n = fn_builder.ins().iadd_imm(n, 8); + fn_builder.ins().jump(push_str_loop_block, &[pdl_size, n]); + fn_builder.seal_block(push_str_loop_check_block); + + // exit_failure + fn_builder.switch_to_block(exit_failure); + let fail_value = fn_builder.ins().iconst(types::I8, 1); + fn_builder.def_var(fail, fail_value); + fn_builder.ins().jump(exit, &[]); + + // exit + fn_builder.switch_to_block(exit); + } + } + } + for wam_instr in code { match wam_instr { // TODO Missing RegType Perm @@ -327,8 +466,8 @@ impl JitMachine { fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); let heap_len = heap_len!(); - let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); - let str_cell = fn_builder.ins().bor(heap_len_shift, str_cell); + let str_loc = fn_builder.ins().iadd_imm(heap_len, -1); + let str_cell = fn_builder.ins().bor(str_loc, str_cell); match reg { RegType::Temp(x) => { registers[x-1] = str_cell; @@ -381,7 +520,7 @@ impl JitMachine { Instruction::GetStructure(_lvl, name, arity, reg) => { let xi = match reg { RegType::Temp(x) => { - registers[x] + registers[x-1] } _ => unimplemented!() }; @@ -394,9 +533,8 @@ impl JitMachine { let start_read_mode_block = fn_builder.create_block(); let fail_block = fn_builder.create_block(); let exit_block = fn_builder.create_block(); + let is_var = is_var!(store); - let tag = fn_builder.ins().ushr_imm(store, 58); - let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); fn_builder.ins().brif(is_var, is_var_block, &[], else_is_var_block, &[]); fn_builder.seal_block(is_var_block); fn_builder.seal_block(else_is_var_block); @@ -406,9 +544,7 @@ impl JitMachine { let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); let heap_len = heap_len!(); - let heap_len_plus = fn_builder.ins().iadd_imm(heap_len, 1); - let str_cell = fn_builder.ins().bor(heap_len_plus, str_cell); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, str_cell]); + let str_cell = fn_builder.ins().bor(heap_len, str_cell); let atom_cell = atom_as_cell!(name, arity); let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); @@ -419,7 +555,7 @@ impl JitMachine { // else_is_var_block fn_builder.switch_to_block(else_is_var_block); - let is_str = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64); + let is_str = is_str!(store); fn_builder.ins().brif(is_str, is_str_block, &[], fail_block, &[]); fn_builder.seal_block(is_str_block); @@ -505,6 +641,46 @@ impl JitMachine { fn_builder.switch_to_block(exit_block); fn_builder.seal_block(exit_block); + } + /* unify_value is an instruction that on WRITE mode behaves like set_value, and in READ mode + * executes unification + */ + // TODO: Manage RegType Perm + Instruction::UnifyValue(reg) => { + let reg = match reg { + RegType::Temp(x) => { + registers[x-1] + } + _ => unimplemented!() + }; + let read_block = fn_builder.create_block(); + let write_block = fn_builder.create_block(); + let exit_block = fn_builder.create_block(); + let mode_value = fn_builder.use_var(mode); + fn_builder.ins().brif(mode_value, write_block, &[], read_block, &[]); + fn_builder.seal_block(read_block); + fn_builder.seal_block(write_block); + // read + fn_builder.switch_to_block(read_block); + let heap_ptr = heap_as_ptr!(); + let s_value = fn_builder.use_var(s); + let idx = fn_builder.ins().iadd(heap_ptr, s_value); + let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + unify!(reg, value); + let s_value = fn_builder.ins().iadd_imm(s_value, 8); + fn_builder.def_var(s, s_value); + fn_builder.ins().jump(exit_block, &[]); + // write + fn_builder.switch_to_block(write_block); + let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); + let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + fn_builder.ins().jump(exit_block, &[]); + fn_builder.seal_block(exit_block); + + // exit + fn_builder.switch_to_block(exit_block); + } // TODO: Manage RegType Perm // TODO: manage NonVar cases @@ -524,7 +700,7 @@ impl JitMachine { bind!(value, c); } Instruction::Proceed => { - fn_builder.ins().return_(®isters); + fn_builder.ins().return_(®isters[0..arity]); break; }, _ => { @@ -539,11 +715,11 @@ impl JitMachine { let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); self.module.define_function(func, &mut self.ctx).unwrap(); - // println!("{}", self.ctx.func.display()); + println!("{}", self.ctx.func.display()); self.module.finalize_definitions().unwrap(); let code_ptr = self.module.get_finalized_function(func); self.predicates.insert((name.to_string(), arity), code_ptr); - // println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); + println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); self.module.clear_context(&mut self.ctx); Ok(()) } @@ -565,3 +741,123 @@ impl JitMachine { fn print_syscall(value: i64) { println!("{}", value); } + + +#[test] +fn test_put_structure() { + let code = vec![ + Instruction::PutStructure(atom!("f"), 2, RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_put_structure", 1, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_put_structure", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 2)); + machine_st_expected.registers[1] = str_loc_as_cell!(0); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.registers, machine_st_expected.registers); +} + +#[test] +fn test_set_variable() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_set_variable", 1, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_set_variable", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(heap_loc_as_cell!(0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.registers, machine_st_expected.registers); +} + +#[test] +fn test_set_value() { + let code = vec![ + Instruction::SetValue(RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_set_value", 1, code).unwrap(); + let mut machine_st = MachineState::new(); + machine_st.registers[1] = atom_as_cell!(atom!("a"), 0); + jit.exec("test_set_value", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("a"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); +} + +#[test] +fn test_fig23_wambook() { + let code = vec![ + Instruction::PutStructure(atom!("h"), 2, RegType::Temp(3)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::SetVariable(RegType::Temp(5)), + Instruction::PutStructure(atom!("f"), 1, RegType::Temp(4)), + Instruction::SetValue(RegType::Temp(5)), + Instruction::PutStructure(atom!("p"), 3, RegType::Temp(1)), + Instruction::SetValue(RegType::Temp(2)), + Instruction::SetValue(RegType::Temp(3)), + Instruction::SetValue(RegType::Temp(4)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_fig23_wambook", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_fig23_wambook", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 2)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(atom_as_cell!(atom!("p"), 3)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(str_loc_as_cell!(0)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + assert_eq!(machine_st.heap, machine_st_expected.heap); +} + +#[test] +fn test_get_structure_read() { + let code = vec![ + Instruction::PutStructure(atom!("f"), 1, RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_get_structure", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_get_structure", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + assert_eq!(machine_st.heap, machine_st_expected.heap); +} + +#[test] +fn test_get_structure_write() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_get_structure", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_get_structure", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(2)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + assert_eq!(machine_st.heap, machine_st_expected.heap); +} + +// TODO: Continue with more tests From 1ea633dcc89f0fa90ebf0c06d3d442df0ac686df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sun, 28 Jul 2024 11:27:24 +0200 Subject: [PATCH 22/28] Return failure, more tests --- src/machine/jit2.rs | 112 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 22 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index dedf4d01c..2404a35a0 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -61,6 +61,7 @@ impl JitMachine { sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); + sig.returns.push(AbiParam::new(types::I8)); sig.call_conv = call_conv; let mut trampolines = vec![]; @@ -82,17 +83,19 @@ impl JitMachine { params.push(heap); for i in 1..n+1 { jump_sig.params.push(AbiParam::new(types::I64)); - jump_sig.returns.push(AbiParam::new(types::I64)); + // jump_sig.returns.push(AbiParam::new(types::I64)); let reg_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), registers, Offset32::new((i as i32)*8)); params.push(reg_value); } + jump_sig.returns.push(AbiParam::new(types::I8)); let jump_sig_ref = fn_builder.import_signature(jump_sig); let jump_call = fn_builder.ins().call_indirect(jump_sig_ref, func_addr, ¶ms); - for i in 0..n { + /*for i in 0..n { let reg_value = fn_builder.inst_results(jump_call)[i]; fn_builder.ins().store(MemFlags::trusted(), reg_value, registers, Offset32::new(((i as i32) + 1) * 8)); - } - fn_builder.ins().return_(&[]); + }*/ + let fail = fn_builder.inst_results(jump_call)[0]; + fn_builder.ins().return_(&[fail]); fn_builder.seal_block(block); fn_builder.finalize(); @@ -144,8 +147,9 @@ impl JitMachine { sig.params.push(AbiParam::new(types::I64)); for _ in 1..=arity { sig.params.push(AbiParam::new(types::I64)); - sig.returns.push(AbiParam::new(types::I64)); + // sig.returns.push(AbiParam::new(types::I64)); } + sig.returns.push(AbiParam::new(types::I8)); sig.call_conv = isa::CallConv::Tail; self.ctx.func.signature = sig.clone(); self.ctx.set_disasm(true); @@ -319,10 +323,7 @@ impl JitMachine { macro_rules! is_var { ($x:expr) => { { - print_rt!($x); let tag = fn_builder.ins().ushr_imm($x, 58); - print_rt!(tag); - dbg!(HeapCellValueTag::Var as i64); fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64) } } @@ -331,10 +332,7 @@ impl JitMachine { macro_rules! is_str { ($x:expr) => { { - print_rt!($x); let tag = fn_builder.ins().ushr_imm($x, 58); - print_rt!(tag); - dbg!(HeapCellValueTag::Str as i64); fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64) } } @@ -606,10 +604,8 @@ impl JitMachine { fn_builder.seal_block(write_block); // read fn_builder.switch_to_block(read_block); - let heap_ptr = heap_as_ptr!(); let s_value = fn_builder.use_var(s); - let idx = fn_builder.ins().iadd(heap_ptr, s_value); - let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), s_value, Offset32::new(0)); let value = deref!(value); match reg { RegType::Temp(x) => { @@ -625,8 +621,7 @@ impl JitMachine { let heap_loc_cell = heap_loc_as_cell!(0); let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); let heap_len = heap_len!(); - let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8); - let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell); + let heap_loc_cell = fn_builder.ins().bor(heap_len, heap_loc_cell); let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); @@ -700,7 +695,10 @@ impl JitMachine { bind!(value, c); } Instruction::Proceed => { - fn_builder.ins().return_(®isters[0..arity]); + // do we really need to return registers? + //fn_builder.ins().return_(®isters[0..arity]); + let fail_value = fn_builder.use_var(fail); + fn_builder.ins().return_(&[fail_value]); break; }, _ => { @@ -728,11 +726,16 @@ impl JitMachine { let Some(predicate) = self.predicates.get(&(name.to_string(), arity)) else { return Err(()); }; - let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec) = unsafe { std::mem::transmute(self.trampolines[arity])}; + let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; let registers = machine_st.registers.as_ptr() as *mut Registers; let heap = &machine_st.heap as *const Vec; - trampoline(*predicate, registers, heap); + let fail = trampoline(*predicate, registers, heap); machine_st.p = machine_st.cp; + machine_st.fail = if fail == 1 { + true + } else { + false + }; Ok(()) } } @@ -755,9 +758,10 @@ fn test_put_structure() { jit.exec("test_put_structure", 1, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 2)); - machine_st_expected.registers[1] = str_loc_as_cell!(0); + // machine_st_expected.registers[1] = str_loc_as_cell!(0); assert_eq!(machine_st.heap, machine_st_expected.heap); - assert_eq!(machine_st.registers, machine_st_expected.registers); + // assert_eq!(machine_st.registers, machine_st_expected.registers); + assert_eq!(machine_st.fail, false); } #[test] @@ -773,7 +777,8 @@ fn test_set_variable() { let mut machine_st_expected = MachineState::new(); machine_st_expected.heap.push(heap_loc_as_cell!(0)); assert_eq!(machine_st.heap, machine_st_expected.heap); - assert_eq!(machine_st.registers, machine_st_expected.registers); + // assert_eq!(machine_st.registers, machine_st_expected.registers); + assert_eq!(machine_st.fail, false); } #[test] @@ -790,6 +795,7 @@ fn test_set_value() { let mut machine_st_expected = MachineState::new(); machine_st_expected.heap.push(atom_as_cell!(atom!("a"), 0)); assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); } #[test] @@ -821,6 +827,7 @@ fn test_fig23_wambook() { machine_st_expected.heap.push(str_loc_as_cell!(0)); machine_st_expected.heap.push(str_loc_as_cell!(3)); assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); } #[test] @@ -839,6 +846,7 @@ fn test_get_structure_read() { machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); machine_st_expected.heap.push(heap_loc_as_cell!(1)); assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); } #[test] @@ -858,6 +866,66 @@ fn test_get_structure_write() { machine_st_expected.heap.push(heap_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + +#[test] +fn test_get_structure_read_fail() { + let code = vec![ + Instruction::PutStructure(atom!("h"), 1, RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_get_structure", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_get_structure", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, true); +} + +#[test] +fn test_unify_variable_write() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::UnifyVariable(RegType::Temp(2)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_variable_write", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_variable_write", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + +#[test] +fn test_unify_variable_read() { + let code = vec![ + Instruction::PutStructure(atom!("f"), 1, RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::UnifyVariable(RegType::Temp(3)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_variable_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_variable_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); } // TODO: Continue with more tests From c5033cefc4fe3b010c847295cb4fd92950250442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Mon, 12 Aug 2024 20:26:53 +0200 Subject: [PATCH 23/28] I think unification finally works :) --- src/machine/jit2.rs | 594 ++++++++++++++++++++++++++++++-------------- 1 file changed, 414 insertions(+), 180 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 2404a35a0..3fb1f15d1 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -26,12 +26,12 @@ pub struct JitMachine { module: JITModule, ctx: Context, func_ctx: FunctionBuilderContext, - heap_as_ptr: *const u8, - heap_as_ptr_sig: Signature, - heap_push: *const u8, - heap_push_sig: Signature, - heap_len: *const u8, - heap_len_sig: Signature, + vec_as_ptr: *const u8, + vec_as_ptr_sig: Signature, + vec_push: *const u8, + vec_push_sig: Signature, + vec_len: *const u8, + vec_len_sig: Signature, predicates: HashMap<(String, usize), *const u8>, } @@ -46,10 +46,13 @@ impl JitMachine { pub fn new() -> Self { let mut builder = JITBuilder::with_flags(&[ ("preserve_frame_pointers", "true"), - ("enable_llvm_abi_extensions", "1")], + ("enable_llvm_abi_extensions", "1"), + ], cranelift_module::default_libcall_names() ).unwrap(); builder.symbol("print_func", print_syscall as *const u8); + builder.symbol("print_func8", print_syscall8 as *const u8); + builder.symbol("vec_pop", vec_pop_syscall as *const u8); let mut module = JITModule::new(builder); let pointer_type = module.isa().pointer_type(); let call_conv = module.isa().default_call_conv(); @@ -61,6 +64,7 @@ impl JitMachine { sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); sig.returns.push(AbiParam::new(types::I8)); sig.call_conv = call_conv; @@ -76,12 +80,15 @@ impl JitMachine { let func_addr = fn_builder.block_params(block)[0]; let registers = fn_builder.block_params(block)[1]; let heap = fn_builder.block_params(block)[2]; + let pdl = fn_builder.block_params(block)[3]; let mut jump_sig = module.make_signature(); jump_sig.call_conv = isa::CallConv::Tail; jump_sig.params.push(AbiParam::new(types::I64)); + jump_sig.params.push(AbiParam::new(types::I64)); let mut params = vec![]; - params.push(heap); - for i in 1..n+1 { + params.push(heap); + params.push(pdl); + for i in 1..=n { jump_sig.params.push(AbiParam::new(types::I64)); // jump_sig.returns.push(AbiParam::new(types::I64)); let reg_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), registers, Offset32::new((i as i32)*8)); @@ -107,18 +114,18 @@ impl JitMachine { let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; trampolines.push(code_ptr); } - let heap_as_ptr = Vec::::as_ptr as *const u8; - let mut heap_as_ptr_sig = module.make_signature(); - heap_as_ptr_sig.params.push(AbiParam::new(pointer_type)); - heap_as_ptr_sig.returns.push(AbiParam::new(pointer_type)); - let heap_push = Vec::::push as *const u8; - let mut heap_push_sig = module.make_signature(); - heap_push_sig.params.push(AbiParam::new(pointer_type)); - heap_push_sig.params.push(AbiParam::new(types::I64)); - let heap_len = Vec::::len as *const u8; - let mut heap_len_sig = module.make_signature(); - heap_len_sig.params.push(AbiParam::new(pointer_type)); - heap_len_sig.returns.push(AbiParam::new(types::I64)); + let vec_as_ptr = Vec::::as_ptr as *const u8; + let mut vec_as_ptr_sig = module.make_signature(); + vec_as_ptr_sig.params.push(AbiParam::new(pointer_type)); + vec_as_ptr_sig.returns.push(AbiParam::new(pointer_type)); + let vec_push = Vec::::push as *const u8; + let mut vec_push_sig = module.make_signature(); + vec_push_sig.params.push(AbiParam::new(pointer_type)); + vec_push_sig.params.push(AbiParam::new(types::I64)); + let vec_len = Vec::::len as *const u8; + let mut vec_len_sig = module.make_signature(); + vec_len_sig.params.push(AbiParam::new(pointer_type)); + vec_len_sig.returns.push(AbiParam::new(types::I64)); let predicates = HashMap::new(); JitMachine { @@ -126,12 +133,12 @@ impl JitMachine { module, ctx, func_ctx, - heap_as_ptr, - heap_as_ptr_sig, - heap_push, - heap_push_sig, - heap_len, - heap_len_sig, + vec_as_ptr, + vec_as_ptr_sig, + vec_push, + vec_push_sig, + vec_len, + vec_len_sig, predicates } } @@ -141,10 +148,23 @@ impl JitMachine { print_func_sig.params.push(AbiParam::new(types::I64)); let print_func = self.module .declare_function("print_func", Linkage::Import, &print_func_sig) - .unwrap(); + .unwrap(); + let mut print_func_sig8 = self.module.make_signature(); + print_func_sig8.params.push(AbiParam::new(types::I8)); + let print_func8 = self.module + .declare_function("print_func8", Linkage::Import, &print_func_sig8) + .unwrap(); + let mut vec_pop_sig = self.module.make_signature(); + vec_pop_sig.params.push(AbiParam::new(self.module.target_config().pointer_type())); + vec_pop_sig.returns.push(AbiParam::new(types::I64)); + let vec_pop = self.module + .declare_function("vec_pop", Linkage::Import, &vec_pop_sig) + .unwrap(); + let mut sig = self.module.make_signature(); sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); for _ in 1..=arity { sig.params.push(AbiParam::new(types::I64)); // sig.returns.push(AbiParam::new(types::I64)); @@ -161,6 +181,7 @@ impl JitMachine { fn_builder.seal_block(block); let heap = fn_builder.block_params(block)[0]; + let pdl = fn_builder.block_params(block)[1]; let mode = Variable::new(0); fn_builder.declare_var(mode, types::I8); let s = Variable::new(1); @@ -174,7 +195,7 @@ impl JitMachine { // TODO: This could be optimized more, we know the maximum register we're using for i in 1..MAX_ARITY { if i <= arity { - let reg = fn_builder.block_params(block)[i]; + let reg = fn_builder.block_params(block)[i + 1]; registers.push(reg); } else { let reg = fn_builder.ins().iconst(types::I64, 0); @@ -191,23 +212,53 @@ impl JitMachine { } } - macro_rules! heap_len { - () => { - {let sig_ref = fn_builder.import_signature(self.heap_len_sig.clone()); - let heap_len_fn = fn_builder.ins().iconst(types::I64, self.heap_len as i64); - let call_heap_len = fn_builder.ins().call_indirect(sig_ref, heap_len_fn, &[heap]); - let heap_len = fn_builder.inst_results(call_heap_len)[0]; - heap_len} + macro_rules! print_rt8 { + ($x:expr) => { + { + let print_func_fn = self.module.declare_func_in_func(print_func8, &mut fn_builder.func); + fn_builder.ins().call(print_func_fn, &[$x]); + } + } + } + + macro_rules! vec_len { + ($x:expr) => { + {let sig_ref = fn_builder.import_signature(self.vec_len_sig.clone()); + let vec_len_fn = fn_builder.ins().iconst(types::I64, self.vec_len as i64); + let call_vec_len = fn_builder.ins().call_indirect(sig_ref, vec_len_fn, &[$x]); + let vec_len = fn_builder.inst_results(call_vec_len)[0]; + vec_len} } } - macro_rules! heap_as_ptr { - () => { + macro_rules! vec_pop { + ($x:expr) => { { - let sig_ref = fn_builder.import_signature(self.heap_as_ptr_sig.clone()); - let heap_as_ptr_fn = fn_builder.ins().iconst(types::I64, self.heap_as_ptr as i64); - let call_heap_as_ptr = fn_builder.ins().call_indirect(sig_ref, heap_as_ptr_fn, &[heap]); - let heap_ptr = fn_builder.inst_results(call_heap_as_ptr)[0]; + let vec_pop_fn = self.module.declare_func_in_func(vec_pop, &mut fn_builder.func); + let call_vec_pop = fn_builder.ins().call(vec_pop_fn, &[$x]); + let value = fn_builder.inst_results(call_vec_pop)[0]; + value + } + } + } + + macro_rules! vec_push { + ($x:expr, $y:expr) => { + { + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[$x, $y]); + } + } + } + + macro_rules! vec_as_ptr { + ($x:expr) => { + { + let sig_ref = fn_builder.import_signature(self.vec_as_ptr_sig.clone()); + let vec_as_ptr_fn = fn_builder.ins().iconst(types::I64, self.vec_as_ptr as i64); + let call_vec_as_ptr = fn_builder.ins().call_indirect(sig_ref, vec_as_ptr_fn, &[$x]); + let heap_ptr = fn_builder.inst_results(call_vec_as_ptr)[0]; heap_ptr } } @@ -232,7 +283,7 @@ impl JitMachine { fn_builder.switch_to_block(is_var_block); fn_builder.seal_block(is_var_block); let param = fn_builder.block_params(is_var_block)[0]; - let heap_ptr = heap_as_ptr!(); + let heap_ptr = vec_as_ptr!(heap); let idx = fn_builder.ins().ishl_imm(param, 8); let idx = fn_builder.ins().ushr_imm(idx, 5); let idx = fn_builder.ins().iadd(heap_ptr, idx); @@ -290,7 +341,7 @@ impl JitMachine { let first_var_block = fn_builder.create_block(); let else_first_var_block = fn_builder.create_block(); let exit_block = fn_builder.create_block(); - let heap_ptr = heap_as_ptr!(); + let heap_ptr = vec_as_ptr!(heap); // check if x is var let tag = fn_builder.ins().ushr_imm($x, 58); let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); @@ -336,119 +387,138 @@ impl JitMachine { fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Str as i64) } } - } + } - // TODO: Unify macro_rules! unify { ($x:expr, $y:expr) => { { - let loop_block = fn_builder.create_block(); - fn_builder.append_block_param(loop_block, types::I64); - let continue_loop_block = fn_builder.create_block(); - let check_var_block = fn_builder.create_block(); - let is_var_block = fn_builder.create_block(); - let is_str_block = fn_builder.create_block(); - let push_str_block = fn_builder.create_block(); - let push_str_loop_block = fn_builder.create_block(); - let push_str_loop_check_block = fn_builder.create_block(); - let exit_failure = fn_builder.create_block(); let exit = fn_builder.create_block(); - let pdl = fn_builder.create_dynamic_stack_slot(DynamicStackSlotData::new(StackSlotKind::ExplicitDynamicSlot, DynamicType::from_u32(64))); + fn_builder.append_block_param(exit, types::I8); + + // start unification + vec_push!(pdl, $x); + vec_push!(pdl, $y); let pdl_size = fn_builder.ins().iconst(types::I64, 2); - fn_builder.ins().dynamic_stack_store($x, pdl); - fn_builder.ins().dynamic_stack_store($y, pdl); - fn_builder.ins().jump(loop_block, &[pdl_size]); - // loop_block + let fail_status = fn_builder.use_var(fail); + + let pre_loop_block = fn_builder.create_block(); + fn_builder.append_block_param(pre_loop_block, types::I64); // pdl_size + fn_builder.append_block_param(pre_loop_block, types::I8); // fail_status + fn_builder.ins().jump(pre_loop_block, &[pdl_size, fail_status]); + fn_builder.switch_to_block(pre_loop_block); + // pre_loop_block + let pdl_size = fn_builder.block_params(pre_loop_block)[0]; + let fail_status = fn_builder.block_params(pre_loop_block)[1]; + let pdl_size_is_zero = fn_builder.ins().icmp_imm(IntCC::Equal, pdl_size, 0); + let fail_status_is_true = fn_builder.ins().icmp_imm(IntCC::NotEqual, fail_status, 0); + let loop_check = fn_builder.ins().bor(pdl_size_is_zero, fail_status_is_true); + let loop_block = fn_builder.create_block(); + fn_builder.append_block_param(loop_block, types::I64); // pdl_size + fn_builder.append_block_param(loop_block, types::I8); // fail_status + + fn_builder.ins().brif(loop_check, exit, &[fail_status], loop_block, &[pdl_size, fail_status]); + fn_builder.seal_block(exit); + fn_builder.seal_block(loop_block); fn_builder.switch_to_block(loop_block); + // loop_block let pdl_size = fn_builder.block_params(loop_block)[0]; - fn_builder.ins().brif(pdl_size, continue_loop_block, &[], exit, &[]); - fn_builder.seal_block(continue_loop_block); - fn_builder.seal_block(exit); - // continue_loop_block - fn_builder.switch_to_block(continue_loop_block); - let d1 = fn_builder.ins().dynamic_stack_load(types::I64, pdl); + let fail_status = fn_builder.block_params(loop_block)[1]; + let d1 = vec_pop!(pdl); + let d2 = vec_pop!(pdl); let d1 = deref!(d1); - let d2 = fn_builder.ins().dynamic_stack_load(types::I64, pdl); let d2 = deref!(d2); let pdl_size = fn_builder.ins().iadd_imm(pdl_size, -2); let are_equal = fn_builder.ins().icmp(IntCC::Equal, d1, d2); - fn_builder.ins().brif(are_equal, loop_block, &[pdl_size], check_var_block, &[]); - fn_builder.seal_block(check_var_block); - // check_var_block - fn_builder.switch_to_block(check_var_block); - let d1 = store!(d1); - let d2 = store!(d2); - let tag_d1 = fn_builder.ins().ushr_imm(d1, 58); - let is_var_d1 = fn_builder.ins().icmp_imm(IntCC::Equal, tag_d1, HeapCellValueTag::Var as i64); - let tag_d2 = fn_builder.ins().ushr_imm(d2, 58); - let is_var_d2 = fn_builder.ins().icmp_imm(IntCC::Equal, tag_d2, HeapCellValueTag::Var as i64); - let any_var = fn_builder.ins().bor(is_var_d1, is_var_d2); - fn_builder.ins().brif(any_var, is_var_block, &[], is_str_block, &[]); - fn_builder.seal_block(is_var_block); - fn_builder.seal_block(is_str_block); + let unify_two_unequal_cells = fn_builder.create_block(); + fn_builder.append_block_param(unify_two_unequal_cells, types::I64); + fn_builder.append_block_param(unify_two_unequal_cells, types::I8); + + fn_builder.ins().brif(are_equal, pre_loop_block, &[pdl_size, fail_status], unify_two_unequal_cells, &[pdl_size, fail_status]); + fn_builder.seal_block(unify_two_unequal_cells); - // is_var_block - fn_builder.switch_to_block(is_var_block); + // unify_two_unequal_cells + fn_builder.switch_to_block(unify_two_unequal_cells); + let pdl_size = fn_builder.block_params(unify_two_unequal_cells)[0]; + let fail_status = fn_builder.block_params(unify_two_unequal_cells)[1]; + let is_var_d1 = is_var!(d1); + let is_var_d2 = is_var!(d2); + let any_is_var = fn_builder.ins().bor(is_var_d1, is_var_d2); + let bind_var = fn_builder.create_block(); + fn_builder.append_block_param(bind_var, types::I64); + fn_builder.append_block_param(bind_var, types::I8); + let unify_str = fn_builder.create_block(); + fn_builder.append_block_param(unify_str, types::I64); + fn_builder.append_block_param(unify_str, types::I8); + fn_builder.ins().brif(any_is_var, bind_var, &[pdl_size, fail_status], unify_str, &[pdl_size, fail_status]); + fn_builder.seal_block(bind_var); + fn_builder.seal_block(unify_str); + + // bind_var + fn_builder.switch_to_block(bind_var); + let pdl_size = fn_builder.block_params(bind_var)[0]; + let fail_status = fn_builder.block_params(bind_var)[1]; bind!(d1, d2); - fn_builder.ins().jump(loop_block, &[pdl_size]); + fn_builder.ins().jump(pre_loop_block, &[pdl_size, fail_status]); - // is_str_block - fn_builder.switch_to_block(is_str_block); - let heap_ptr = heap_as_ptr!(); - let idx = fn_builder.ins().ishl_imm(d1, 8); - let idx = fn_builder.ins().ushr_imm(idx, 5); - let idx1 = fn_builder.ins().iadd(heap_ptr, idx); - let heap_d1 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx1, Offset32::new(0)); - let idx = fn_builder.ins().ishl_imm(d2, 8); - let idx = fn_builder.ins().ushr_imm(idx, 5); - let idx2 = fn_builder.ins().iadd(heap_ptr, idx); - let heap_d2 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx2, Offset32::new(0)); - let is_str_equal = fn_builder.ins().icmp(IntCC::Equal, heap_d1, heap_d2); - fn_builder.ins().brif(is_str_equal, push_str_block, &[], exit_failure, &[]); - fn_builder.seal_block(push_str_block); - fn_builder.seal_block(exit_failure); - - // push_str_block - fn_builder.switch_to_block(push_str_block); - let arity = fn_builder.ins().ishl_imm(heap_d1, 8); + // unify_str + fn_builder.switch_to_block(unify_str); + let pdl_size = fn_builder.block_params(unify_str)[0]; + let fail_status = fn_builder.block_params(unify_str)[1]; + let heap_ptr = vec_as_ptr!(heap); + let idx1 = fn_builder.ins().ishl_imm(d1, 8); + let idx1 = fn_builder.ins().ushr_imm(idx1, 5); + let idx1 = fn_builder.ins().iadd(heap_ptr, idx1); + let d1 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx1, Offset32::new(0)); + let idx2 = fn_builder.ins().ishl_imm(d2, 8); + let idx2 = fn_builder.ins().ushr_imm(idx2, 5); + let idx2 = fn_builder.ins().iadd(heap_ptr, idx2); + let d2 = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx2, Offset32::new(0)); + let are_atom_arity_equal = fn_builder.ins().icmp(IntCC::Equal, d1, d2); + let fail_status_bad = fn_builder.ins().iconst(types::I8, 1); + let pre_push_subterms = fn_builder.create_block(); + fn_builder.append_block_param(pre_push_subterms, types::I64); + fn_builder.append_block_param(pre_push_subterms, types::I64); + fn_builder.append_block_param(pre_push_subterms, types::I8); + let arity = fn_builder.ins().ishl_imm(d1, 8); let arity = fn_builder.ins().ushr_imm(arity, 54); - let n = fn_builder.ins().iconst(types::I64, 8); - fn_builder.ins().jump(push_str_loop_check_block, &[pdl_size, n]); - fn_builder.seal_block(push_str_loop_check_block); - // push_str_loop_check_block - fn_builder.switch_to_block(push_str_loop_check_block); - let pdl_size = fn_builder.block_params(push_str_loop_check_block)[0]; - let n = fn_builder.block_params(push_str_loop_check_block)[1]; - let n_is_arity = fn_builder.ins().icmp(IntCC::Equal, arity, n); - fn_builder.ins().brif(n_is_arity, loop_block, &[], push_str_loop_block, &[pdl_size, n]); - fn_builder.seal_block(loop_block); + fn_builder.ins().brif(are_atom_arity_equal, pre_push_subterms, &[arity, pdl_size, fail_status], pre_loop_block, &[pdl_size, fail_status_bad]); - // push_str_loop_block - fn_builder.switch_to_block(push_str_loop_block); - let n = fn_builder.block_params(push_str_loop_block)[0]; - let v1 = fn_builder.ins().iadd(idx1, n); - let v2 = fn_builder.ins().iadd(idx2, n); - let heap_v1 = fn_builder.ins().load(types::I64, MemFlags::trusted(), v1, Offset32::new(0)); - let heap_v2 = fn_builder.ins().load(types::I64, MemFlags::trusted(), v2, Offset32::new(0)); - fn_builder.ins().dynamic_stack_store(heap_v1, pdl); - fn_builder.ins().dynamic_stack_store(heap_v2, pdl); + // pre_push_subterms + fn_builder.switch_to_block(pre_push_subterms); + let arity = fn_builder.block_params(pre_push_subterms)[0]; + let pdl_size = fn_builder.block_params(pre_push_subterms)[1]; + let fail_status = fn_builder.block_params(pre_push_subterms)[2]; + let zero_remaining = fn_builder.ins().icmp_imm(IntCC::Equal, arity, 0); + let push_subterms = fn_builder.create_block(); + fn_builder.append_block_param(push_subterms, types::I64); + fn_builder.append_block_param(push_subterms, types::I64); + fn_builder.append_block_param(push_subterms, types::I8); + fn_builder.ins().brif(zero_remaining, pre_loop_block, &[pdl_size, fail_status], push_subterms, &[arity, pdl_size, fail_status]); + fn_builder.seal_block(pre_loop_block); + fn_builder.seal_block(push_subterms); + // push_subterms + fn_builder.switch_to_block(push_subterms); + let d1_next = fn_builder.ins().iadd_imm(idx1, 8); + let d1_next = fn_builder.ins().iadd(heap_ptr, d1_next); + let d1_next = fn_builder.ins().load(types::I64, MemFlags::trusted(), d1_next, Offset32::new(0)); + vec_push!(pdl, d1_next); + let d2_next = fn_builder.ins().iadd_imm(idx2, 8); + let d2_next = fn_builder.ins().iadd(heap_ptr, d2_next); + let d2_next = fn_builder.ins().load(types::I64, MemFlags::trusted(), d2_next, Offset32::new(0)); + vec_push!(pdl, d2_next); let pdl_size = fn_builder.ins().iadd_imm(pdl_size, 2); - let n = fn_builder.ins().iadd_imm(n, 8); - fn_builder.ins().jump(push_str_loop_block, &[pdl_size, n]); - fn_builder.seal_block(push_str_loop_check_block); + let arity = fn_builder.ins().iadd_imm(arity, -1); + fn_builder.ins().jump(pre_push_subterms, &[arity, pdl_size, fail_status]); + fn_builder.seal_block(pre_push_subterms); - // exit_failure - fn_builder.switch_to_block(exit_failure); - let fail_value = fn_builder.ins().iconst(types::I8, 1); - fn_builder.def_var(fail, fail_value); - fn_builder.ins().jump(exit, &[]); - // exit fn_builder.switch_to_block(exit); + let fail_status = fn_builder.block_params(exit)[0]; + fn_builder.def_var(fail, fail_status); } } } - + for wam_instr in code { match wam_instr { // TODO Missing RegType Perm @@ -457,15 +527,16 @@ impl JitMachine { * It also saves the STR cell into a register */ Instruction::PutStructure(name, arity, reg) => { - let atom_cell = atom_as_cell!(name, arity); - let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); - let heap_len = heap_len!(); - let str_loc = fn_builder.ins().iadd_imm(heap_len, -1); + let vec_len = vec_len!(heap); + let str_loc = fn_builder.ins().iadd_imm(vec_len, 1); let str_cell = fn_builder.ins().bor(str_loc, str_cell); + + let atom_cell = atom_as_cell!(name, arity); + let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); + vec_push!(heap, str_cell); + vec_push!(heap, atom); + match reg { RegType::Temp(x) => { registers[x-1] = str_cell; @@ -480,11 +551,11 @@ impl JitMachine { Instruction::SetVariable(reg) => { let heap_loc_cell = heap_loc_as_cell!(0); let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); - let heap_len = heap_len!(); - let heap_loc_cell = fn_builder.ins().bor(heap_len, heap_loc_cell); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + let vec_len = vec_len!(heap); + let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); match reg { RegType::Temp(x) => { registers[x-1] = heap_loc_cell; @@ -504,9 +575,9 @@ impl JitMachine { }; let value = store!(value); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, value]); } // TODO: Missing RegType Perm /* get_structure is an instruction that either matches an existing STR or starts writing a new @@ -538,14 +609,14 @@ impl JitMachine { fn_builder.seal_block(else_is_var_block); // is_var_block fn_builder.switch_to_block(is_var_block); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes())); - let heap_len = heap_len!(); - let str_cell = fn_builder.ins().bor(heap_len, str_cell); + let vec_len = vec_len!(heap); + let str_cell = fn_builder.ins().bor(vec_len, str_cell); let atom_cell = atom_as_cell!(name, arity); let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, atom]); bind!(store, str_cell); let mode_value = fn_builder.ins().iconst(types::I8, 1); fn_builder.def_var(mode, mode_value); @@ -560,7 +631,7 @@ impl JitMachine { // is_str_block fn_builder.switch_to_block(is_str_block); let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes())); - let heap_ptr = heap_as_ptr!(); + let heap_ptr = vec_as_ptr!(heap); let idx = fn_builder.ins().ishl_imm(store, 8); let idx = fn_builder.ins().ushr_imm(idx, 5); let idx = fn_builder.ins().iadd(heap_ptr, idx); @@ -572,7 +643,7 @@ impl JitMachine { // start_read_mode_block fn_builder.switch_to_block(start_read_mode_block); - let s_ptr = fn_builder.ins().iadd_imm(heap_ptr, 8); + let s_ptr = fn_builder.ins().iadd_imm(idx, 8); fn_builder.def_var(s, s_ptr); let mode_value = fn_builder.ins().iconst(types::I8, 0); fn_builder.def_var(mode, mode_value); @@ -620,11 +691,11 @@ impl JitMachine { fn_builder.switch_to_block(write_block); let heap_loc_cell = heap_loc_as_cell!(0); let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); - let heap_len = heap_len!(); - let heap_loc_cell = fn_builder.ins().bor(heap_len, heap_loc_cell); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]); + let vec_len = vec_len!(heap); + let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); match reg { RegType::Temp(x) => { registers[x-1] = heap_loc_cell; @@ -657,25 +728,34 @@ impl JitMachine { fn_builder.seal_block(write_block); // read fn_builder.switch_to_block(read_block); - let heap_ptr = heap_as_ptr!(); let s_value = fn_builder.use_var(s); - let idx = fn_builder.ins().iadd(heap_ptr, s_value); - let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0)); + let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), s_value, Offset32::new(0)); unify!(reg, value); let s_value = fn_builder.ins().iadd_imm(s_value, 8); fn_builder.def_var(s, s_value); fn_builder.ins().jump(exit_block, &[]); // write fn_builder.switch_to_block(write_block); - let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone()); - let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64); - fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, reg]); fn_builder.ins().jump(exit_block, &[]); fn_builder.seal_block(exit_block); // exit fn_builder.switch_to_block(exit_block); + } + Instruction::GetValue(reg1, reg2) => { + let reg1 = match reg1 { + RegType::Temp(x) => { + registers[x-1] + } + _ => unimplemented!() + }; + let reg2 = registers[reg2 - 1]; + unify!(reg1, reg2); + } // TODO: Manage RegType Perm // TODO: manage NonVar cases @@ -726,10 +806,11 @@ impl JitMachine { let Some(predicate) = self.predicates.get(&(name.to_string(), arity)) else { return Err(()); }; - let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; + let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; let registers = machine_st.registers.as_ptr() as *mut Registers; let heap = &machine_st.heap as *const Vec; - let fail = trampoline(*predicate, registers, heap); + let pdl = &machine_st.pdl as *const Vec; + let fail = trampoline(*predicate, registers, heap, pdl); machine_st.p = machine_st.cp; machine_st.fail = if fail == 1 { true @@ -745,6 +826,14 @@ fn print_syscall(value: i64) { println!("{}", value); } +fn print_syscall8(value: i8) { + println!("{}", value); +} + +fn vec_pop_syscall(value: &mut Vec) -> HeapCellValue { + value.pop().unwrap() +} + #[test] fn test_put_structure() { @@ -757,6 +846,7 @@ fn test_put_structure() { let mut machine_st = MachineState::new(); jit.exec("test_put_structure", 1, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 2)); // machine_st_expected.registers[1] = str_loc_as_cell!(0); assert_eq!(machine_st.heap, machine_st_expected.heap); @@ -817,15 +907,18 @@ fn test_fig23_wambook() { let mut machine_st = MachineState::new(); jit.exec("test_fig23_wambook", 0, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 2)); - machine_st_expected.heap.push(heap_loc_as_cell!(1)); machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(heap_loc_as_cell!(3)); + machine_st_expected.heap.push(str_loc_as_cell!(5)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); - machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(heap_loc_as_cell!(3)); + machine_st_expected.heap.push(str_loc_as_cell!(8)); machine_st_expected.heap.push(atom_as_cell!(atom!("p"), 3)); - machine_st_expected.heap.push(heap_loc_as_cell!(1)); - machine_st_expected.heap.push(str_loc_as_cell!(0)); - machine_st_expected.heap.push(str_loc_as_cell!(3)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(str_loc_as_cell!(5)); assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); } @@ -843,8 +936,9 @@ fn test_get_structure_read() { let mut machine_st = MachineState::new(); jit.exec("test_get_structure", 0, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); - machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); } @@ -882,8 +976,9 @@ fn test_get_structure_read_fail() { let mut machine_st = MachineState::new(); jit.exec("test_get_structure", 0, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 1)); - machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, true); } @@ -909,23 +1004,162 @@ fn test_unify_variable_write() { } #[test] -fn test_unify_variable_read() { +fn test_unify_value_write() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::UnifyValue(RegType::Temp(2)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_write", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_write", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(2)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(heap_loc_as_cell!(1)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + +#[test] +fn test_unify_value_read_str() { + let code = vec![ + Instruction::PutStructure(atom!("f"), 1, RegType::Temp(1)), + Instruction::PutStructure(atom!("h"), 0, RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(3)), + Instruction::UnifyValue(RegType::Temp(3)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + +#[test] +fn test_unify_value_read_str_2() { + let code = vec![ + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(1)), + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(2)), + Instruction::GetValue(RegType::Temp(1), 2), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); + +} + +#[test] +fn test_unify_value_read_str_3() { let code = vec![ Instruction::PutStructure(atom!("f"), 1, RegType::Temp(1)), Instruction::SetVariable(RegType::Temp(2)), + Instruction::SetVariable(RegType::Temp(3)), Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), - Instruction::UnifyVariable(RegType::Temp(3)), + Instruction::UnifyValue(RegType::Temp(3)), Instruction::Proceed ]; let mut jit = JitMachine::new(); - jit.compile("test_unify_variable_read", 0, code).unwrap(); + jit.compile("test_unify_value_read", 0, code).unwrap(); let mut machine_st = MachineState::new(); - jit.exec("test_unify_variable_read", 0, &mut machine_st).unwrap(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); - machine_st_expected.heap.push(heap_loc_as_cell!(1)); + machine_st_expected.heap.push(heap_loc_as_cell!(3)); + machine_st_expected.heap.push(heap_loc_as_cell!(3)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); + +} + + +#[test] +fn test_unify_value_1() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::GetValue(RegType::Temp(1), 1), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(heap_loc_as_cell!(0)); assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); } + +#[test] +fn test_unify_value_2() { + let code = vec![ + Instruction::SetVariable(RegType::Temp(1)), + Instruction::SetVariable(RegType::Temp(2)), + Instruction::GetValue(RegType::Temp(1), 2), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(heap_loc_as_cell!(0)); + machine_st_expected.heap.push(heap_loc_as_cell!(0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + + +#[test] +fn test_unify_value_read_fail() { + let code = vec![ + Instruction::PutStructure(atom!("h"), 0, RegType::Temp(2)), + Instruction::PutStructure(atom!("z"), 0, RegType::Temp(3)), + Instruction::PutStructure(atom!("f"), 1, RegType::Temp(1)), + Instruction::SetValue(RegType::Temp(2)), + Instruction::SetValue(RegType::Temp(3)), + Instruction::GetStructure(Level::Shallow, atom!("f"), 1, RegType::Temp(1)), + Instruction::UnifyValue(RegType::Temp(3)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("test_unify_value_read", 0, code).unwrap(); + let mut machine_st = MachineState::new(); + jit.exec("test_unify_value_read", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(atom_as_cell!(atom!("h"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + machine_st_expected.heap.push(atom_as_cell!(atom!("z"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(5)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); + machine_st_expected.heap.push(str_loc_as_cell!(1)); + machine_st_expected.heap.push(str_loc_as_cell!(3)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, true); +} // TODO: Continue with more tests From cebab22a037dab3d202ac663dc0e0711561f0eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 13 Aug 2024 16:08:21 +0200 Subject: [PATCH 24/28] Add Call and Execute --- src/machine/jit2.rs | 168 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 141 insertions(+), 27 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 3fb1f15d1..1a73834c6 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -8,7 +8,7 @@ use crate::parser::ast::*; use cranelift::prelude::*; use cranelift_jit::{JITBuilder, JITModule}; -use cranelift_module::{Linkage, Module}; +use cranelift_module::{Linkage, Module, FuncId}; use cranelift_codegen::{ir::stackslot::*, ir::entities::*, Context}; use cranelift::prelude::codegen::ir::immediates::Offset32; @@ -32,7 +32,15 @@ pub struct JitMachine { vec_push_sig: Signature, vec_len: *const u8, vec_len_sig: Signature, - predicates: HashMap<(String, usize), *const u8>, + print_func: FuncId, + print_func8: FuncId, + vec_pop: FuncId, + predicates: HashMap<(String, usize), JitPredicate>, +} + +pub struct JitPredicate { + code_ptr: *const u8, + func_id: FuncId, } impl std::fmt::Debug for JitMachine { @@ -114,6 +122,25 @@ impl JitMachine { let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) }; trampolines.push(code_ptr); } + + let mut print_func_sig = module.make_signature(); + print_func_sig.params.push(AbiParam::new(types::I64)); + let print_func = module + .declare_function("print_func", Linkage::Import, &print_func_sig) + .unwrap(); + let mut print_func_sig8 = module.make_signature(); + print_func_sig8.params.push(AbiParam::new(types::I8)); + let print_func8 = module + .declare_function("print_func8", Linkage::Import, &print_func_sig8) + .unwrap(); + + let mut vec_pop_sig = module.make_signature(); + vec_pop_sig.params.push(AbiParam::new(pointer_type)); + vec_pop_sig.returns.push(AbiParam::new(types::I64)); + let vec_pop = module + .declare_function("vec_pop", Linkage::Import, &vec_pop_sig) + .unwrap(); + let vec_as_ptr = Vec::::as_ptr as *const u8; let mut vec_as_ptr_sig = module.make_signature(); vec_as_ptr_sig.params.push(AbiParam::new(pointer_type)); @@ -139,28 +166,14 @@ impl JitMachine { vec_push_sig, vec_len, vec_len_sig, + print_func, + print_func8, + vec_pop, predicates } } pub fn compile(&mut self, name: &str, arity: usize, code: Code) -> Result<(), JitCompileError> { - let mut print_func_sig = self.module.make_signature(); - print_func_sig.params.push(AbiParam::new(types::I64)); - let print_func = self.module - .declare_function("print_func", Linkage::Import, &print_func_sig) - .unwrap(); - let mut print_func_sig8 = self.module.make_signature(); - print_func_sig8.params.push(AbiParam::new(types::I8)); - let print_func8 = self.module - .declare_function("print_func8", Linkage::Import, &print_func_sig8) - .unwrap(); - - let mut vec_pop_sig = self.module.make_signature(); - vec_pop_sig.params.push(AbiParam::new(self.module.target_config().pointer_type())); - vec_pop_sig.returns.push(AbiParam::new(types::I64)); - let vec_pop = self.module - .declare_function("vec_pop", Linkage::Import, &vec_pop_sig) - .unwrap(); let mut sig = self.module.make_signature(); sig.params.push(AbiParam::new(types::I64)); @@ -206,7 +219,7 @@ impl JitMachine { macro_rules! print_rt { ($x:expr) => { { - let print_func_fn = self.module.declare_func_in_func(print_func, &mut fn_builder.func); + let print_func_fn = self.module.declare_func_in_func(self.print_func, &mut fn_builder.func); fn_builder.ins().call(print_func_fn, &[$x]); } } @@ -215,7 +228,7 @@ impl JitMachine { macro_rules! print_rt8 { ($x:expr) => { { - let print_func_fn = self.module.declare_func_in_func(print_func8, &mut fn_builder.func); + let print_func_fn = self.module.declare_func_in_func(self.print_func8, &mut fn_builder.func); fn_builder.ins().call(print_func_fn, &[$x]); } } @@ -234,7 +247,7 @@ impl JitMachine { macro_rules! vec_pop { ($x:expr) => { { - let vec_pop_fn = self.module.declare_func_in_func(vec_pop, &mut fn_builder.func); + let vec_pop_fn = self.module.declare_func_in_func(self.vec_pop, &mut fn_builder.func); let call_vec_pop = fn_builder.ins().call(vec_pop_fn, &[$x]); let value = fn_builder.inst_results(call_vec_pop)[0]; value @@ -746,6 +759,47 @@ impl JitMachine { fn_builder.switch_to_block(exit_block); } + /* put_variable works similar to set_variable, but it stores the cell in two registers, + * Xi normal register and Ai argument register + */ + Instruction::PutVariable(reg, arg) => { + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let vec_len = vec_len!(heap); + let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); + let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); + let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); + fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); + match reg { + RegType::Temp(x) => { + registers[x-1] = heap_loc_cell; + } + _ => unimplemented!() + } + registers[arg - 1] = heap_loc_cell; + } + /* put_value moves the content from Xi to Ai + */ + Instruction::PutValue(reg, arg) => { + registers[arg - 1] = match reg { + RegType::Temp(x) => { + registers[x-1] + } + _ => unimplemented!() + }; + } + /* get_variable moves the content from Ai to Xi + */ + Instruction::GetVariable(reg, arg) => { + match reg { + RegType::Temp(x) => { + registers[x-1] = registers[arg - 1]; + } + _ => unimplemented!() + } + } + /* get_value perform unification between Xi and Ai + */ Instruction::GetValue(reg1, reg2) => { let reg1 = match reg1 { RegType::Temp(x) => { @@ -757,6 +811,41 @@ impl JitMachine { unify!(reg1, reg2); } + Instruction::CallNamed(arity, name, _ ) => { +let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { + return Err(JitCompileError::UndefinedPredicate); + }; + let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); + let mut args = vec![]; + args.push(heap); + args.push(pdl); + for i in 1..=arity { + args.push(registers[i-1]); + } + let call = fn_builder.ins().call(func, &args); + let fail_status = fn_builder.inst_results(call)[0]; + let exit_early = fn_builder.create_block(); + let resume = fn_builder.create_block(); + fn_builder.ins().brif(fail_status, exit_early, &[], resume, &[]); + fn_builder.seal_block(exit_early); + fn_builder.seal_block(resume); + fn_builder.switch_to_block(exit_early); + fn_builder.ins().return_(&[fail_status]); + fn_builder.switch_to_block(resume); + } + Instruction::ExecuteNamed(arity, name, _) => { + let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { + return Err(JitCompileError::UndefinedPredicate); + }; + let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); + let mut args = vec![]; + args.push(heap); + args.push(pdl); + for i in 1..=arity { + args.push(registers[i-1]); + } + fn_builder.ins().return_call(func, &args); + } // TODO: Manage RegType Perm // TODO: manage NonVar cases // TODO: Manage unification case @@ -782,6 +871,7 @@ impl JitMachine { break; }, _ => { + dbg!(wam_instr); fn_builder.finalize(); self.module.clear_context(&mut self.ctx); return Err(JitCompileError::InstructionNotImplemented); @@ -791,12 +881,15 @@ impl JitMachine { fn_builder.seal_all_blocks(); fn_builder.finalize(); - let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap(); - self.module.define_function(func, &mut self.ctx).unwrap(); + let func_id = self.module.declare_function(&format!("{}/{}", name, arity), Linkage::Local, &sig).unwrap(); + self.module.define_function(func_id, &mut self.ctx).unwrap(); println!("{}", self.ctx.func.display()); self.module.finalize_definitions().unwrap(); - let code_ptr = self.module.get_finalized_function(func); - self.predicates.insert((name.to_string(), arity), code_ptr); + let code_ptr = self.module.get_finalized_function(func_id); + self.predicates.insert((name.to_string(), arity), JitPredicate { + code_ptr, + func_id + }); println!("{}", self.ctx.compiled_code().unwrap().vcode.clone().unwrap()); self.module.clear_context(&mut self.ctx); Ok(()) @@ -810,7 +903,7 @@ impl JitMachine { let registers = machine_st.registers.as_ptr() as *mut Registers; let heap = &machine_st.heap as *const Vec; let pdl = &machine_st.pdl as *const Vec; - let fail = trampoline(*predicate, registers, heap, pdl); + let fail = trampoline(predicate.code_ptr, registers, heap, pdl); machine_st.p = machine_st.cp; machine_st.fail = if fail == 1 { true @@ -1162,4 +1255,25 @@ fn test_unify_value_read_fail() { assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, true); } + +#[test] +fn test_execute_named() { + let mut machine_st = MachineState::new(); + let code_b = vec![ + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("f"), 0), RegType::Temp(1)), + Instruction::Proceed + ]; + let code_a = vec![ + Instruction::PutVariable(RegType::Temp(1), 1), + Instruction::ExecuteNamed(1, atom!("b"), CodeIndex::default(&mut machine_st.arena)), + ]; + let mut jit = JitMachine::new(); + jit.compile("b", 1, code_b).unwrap(); + jit.compile("a", 0, code_a).unwrap(); + jit.exec("a", 0, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} // TODO: Continue with more tests From 520ede4ee735d3af949f86abea1381f68db5aa3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Fri, 16 Aug 2024 16:09:15 +0200 Subject: [PATCH 25/28] Stack allocation --- src/machine/jit2.rs | 268 ++++++++++++++++++++++++++++---------------- 1 file changed, 169 insertions(+), 99 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 1a73834c6..ae97d5307 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -32,6 +32,8 @@ pub struct JitMachine { vec_push_sig: Signature, vec_len: *const u8, vec_len_sig: Signature, + vec_truncate: *const u8, + vec_truncate_sig: Signature, print_func: FuncId, print_func8: FuncId, vec_pop: FuncId, @@ -73,6 +75,7 @@ impl JitMachine { sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); sig.returns.push(AbiParam::new(types::I8)); sig.call_conv = call_conv; @@ -89,13 +92,19 @@ impl JitMachine { let registers = fn_builder.block_params(block)[1]; let heap = fn_builder.block_params(block)[2]; let pdl = fn_builder.block_params(block)[3]; + let stack = fn_builder.block_params(block)[4]; + let e_pointer = fn_builder.ins().iconst(types::I64, 0); let mut jump_sig = module.make_signature(); jump_sig.call_conv = isa::CallConv::Tail; jump_sig.params.push(AbiParam::new(types::I64)); jump_sig.params.push(AbiParam::new(types::I64)); + jump_sig.params.push(AbiParam::new(types::I64)); + jump_sig.params.push(AbiParam::new(types::I64)); let mut params = vec![]; params.push(heap); params.push(pdl); + params.push(stack); + params.push(e_pointer); for i in 1..=n { jump_sig.params.push(AbiParam::new(types::I64)); // jump_sig.returns.push(AbiParam::new(types::I64)); @@ -153,6 +162,10 @@ impl JitMachine { let mut vec_len_sig = module.make_signature(); vec_len_sig.params.push(AbiParam::new(pointer_type)); vec_len_sig.returns.push(AbiParam::new(types::I64)); + let vec_truncate = Vec::::truncate as *const u8; + let mut vec_truncate_sig = module.make_signature(); + vec_truncate_sig.params.push(AbiParam::new(pointer_type)); + vec_truncate_sig.params.push(AbiParam::new(types::I64)); let predicates = HashMap::new(); JitMachine { @@ -166,6 +179,8 @@ impl JitMachine { vec_push_sig, vec_len, vec_len_sig, + vec_truncate, + vec_truncate_sig, print_func, print_func8, vec_pop, @@ -178,6 +193,8 @@ impl JitMachine { let mut sig = self.module.make_signature(); sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); for _ in 1..=arity { sig.params.push(AbiParam::new(types::I64)); // sig.returns.push(AbiParam::new(types::I64)); @@ -195,6 +212,8 @@ impl JitMachine { let heap = fn_builder.block_params(block)[0]; let pdl = fn_builder.block_params(block)[1]; + let stack = fn_builder.block_params(block)[2]; + let e_pointer = fn_builder.block_params(block)[3]; let mode = Variable::new(0); fn_builder.declare_var(mode, types::I8); let s = Variable::new(1); @@ -203,12 +222,15 @@ impl JitMachine { fn_builder.declare_var(fail, types::I8); let fail_value_init = fn_builder.ins().iconst(types::I8, 0); fn_builder.def_var(fail, fail_value_init); + let e = Variable::new(3); + fn_builder.declare_var(e, types::I64); + fn_builder.def_var(e, e_pointer); let mut registers = vec![]; // TODO: This could be optimized more, we know the maximum register we're using for i in 1..MAX_ARITY { if i <= arity { - let reg = fn_builder.block_params(block)[i + 1]; + let reg = fn_builder.block_params(block)[i + 3]; registers.push(reg); } else { let reg = fn_builder.ins().iconst(types::I64, 0); @@ -265,6 +287,16 @@ impl JitMachine { } } + macro_rules! vec_truncate { + ($x:expr, $y:expr) => { + { + let sig_ref = fn_builder.import_signature(self.vec_truncate_sig.clone()); + let vec_truncate_fn = fn_builder.ins().iconst(types::I64, self.vec_truncate as i64); + fn_builder.ins().call_indirect(sig_ref, vec_truncate_fn, &[$x, $y]); + } + } + } + macro_rules! vec_as_ptr { ($x:expr) => { { @@ -532,9 +564,40 @@ impl JitMachine { } } + macro_rules! read_reg { + ($x:expr) => { + { + match $x { + RegType::Temp(x) => { + registers[x-1] + } + RegType::Perm(y) => { + let idy = ((y as i32) + 1) * 8; + let stack_frame = fn_builder.use_var(e); + fn_builder.ins().load(types::I64, MemFlags::trusted(), stack_frame, Offset32::new(idy)) + } + } + } + } + } + + macro_rules! write_reg { + ($x:expr, $y:expr) => { + match $x { + RegType::Temp(x) => { + registers[x-1] = $y; + } + RegType::Perm(y) => { + let idy = ((y as i32) + 1) * 8; + let stack_frame = fn_builder.use_var(e); + fn_builder.ins().store(MemFlags::trusted(), $y, stack_frame, Offset32::new(idy)); + } + } + } + } + for wam_instr in code { match wam_instr { - // TODO Missing RegType Perm /* put_structure is an instruction that puts a new STR in the heap * (STR cell, plus functor + arity cell) * It also saves the STR cell into a register @@ -550,14 +613,8 @@ impl JitMachine { vec_push!(heap, str_cell); vec_push!(heap, atom); - match reg { - RegType::Temp(x) => { - registers[x-1] = str_cell; - } - _ => unimplemented!() - } + write_reg!(reg, str_cell); } - // TODO Missing RegType Perm /* set_variable is an instruction that creates a new self-referential REF * (unbounded variable) in the heap and saves that into a register */ @@ -566,33 +623,16 @@ impl JitMachine { let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); let vec_len = vec_len!(heap); let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); - let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); - let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); - match reg { - RegType::Temp(x) => { - registers[x-1] = heap_loc_cell; - } - _ => unimplemented!() - } + vec_push!(heap, heap_loc_cell); + write_reg!(reg, heap_loc_cell); } - // TODO: Missing RegType Perm /* set_value is an instruction that pushes a new data cell from a register to the heap */ Instruction::SetValue(reg) => { - let value = match reg { - RegType::Temp(x) => { - registers[x-1] - }, - _ => unimplemented!() - }; + let value = read_reg!(reg); let value = store!(value); - - let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); - let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, value]); + vec_push!(heap, value); } - // TODO: Missing RegType Perm /* get_structure is an instruction that either matches an existing STR or starts writing a new * STR into the heap. If the cell passed via register is an unbounded REF, we start WRITE mode * and unify_variable, unify_value behave similar to set_variable, set_value. @@ -600,12 +640,7 @@ impl JitMachine { * in that mod unify_variable and unify_value will follow the unification procedure */ Instruction::GetStructure(_lvl, name, arity, reg) => { - let xi = match reg { - RegType::Temp(x) => { - registers[x-1] - } - _ => unimplemented!() - }; + let xi = read_reg!(reg); let deref = deref!(xi); let store = store!(deref); @@ -673,7 +708,6 @@ impl JitMachine { fn_builder.switch_to_block(exit_block); } - // TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate // TODO: Missing support for PStr and CStr /* unify_variable is an instruction that in WRITE mode is identical to set_variable but * in READ mode it reads the data cell from the S pointer to a register @@ -691,12 +725,7 @@ impl JitMachine { let s_value = fn_builder.use_var(s); let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), s_value, Offset32::new(0)); let value = deref!(value); - match reg { - RegType::Temp(x) => { - registers[x-1] = value; - }, - _ => unimplemented!() - } + write_reg!(reg, value); let sum_s = fn_builder.ins().iadd_imm(s_value, 8); fn_builder.def_var(s, sum_s); fn_builder.ins().jump(exit_block, &[]); @@ -706,15 +735,8 @@ impl JitMachine { let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); let vec_len = vec_len!(heap); let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); - let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); - let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); - match reg { - RegType::Temp(x) => { - registers[x-1] = heap_loc_cell; - } - _ => unimplemented!() - } + vec_push!(heap, heap_loc_cell); + write_reg!(reg, heap_loc_cell); fn_builder.ins().jump(exit_block, &[]); // exit fn_builder.switch_to_block(exit_block); @@ -724,14 +746,8 @@ impl JitMachine { /* unify_value is an instruction that on WRITE mode behaves like set_value, and in READ mode * executes unification */ - // TODO: Manage RegType Perm Instruction::UnifyValue(reg) => { - let reg = match reg { - RegType::Temp(x) => { - registers[x-1] - } - _ => unimplemented!() - }; + let reg = read_reg!(reg); let read_block = fn_builder.create_block(); let write_block = fn_builder.create_block(); let exit_block = fn_builder.create_block(); @@ -749,9 +765,7 @@ impl JitMachine { fn_builder.ins().jump(exit_block, &[]); // write fn_builder.switch_to_block(write_block); - let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); - let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, reg]); + vec_push!(heap, reg); fn_builder.ins().jump(exit_block, &[]); fn_builder.seal_block(exit_block); @@ -767,58 +781,42 @@ impl JitMachine { let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); let vec_len = vec_len!(heap); let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); - let sig_ref = fn_builder.import_signature(self.vec_push_sig.clone()); - let vec_push_fn = fn_builder.ins().iconst(types::I64, self.vec_push as i64); - fn_builder.ins().call_indirect(sig_ref, vec_push_fn, &[heap, heap_loc_cell]); - match reg { - RegType::Temp(x) => { - registers[x-1] = heap_loc_cell; - } - _ => unimplemented!() - } + vec_push!(heap, heap_loc_cell); + write_reg!(reg, heap_loc_cell); registers[arg - 1] = heap_loc_cell; } /* put_value moves the content from Xi to Ai */ Instruction::PutValue(reg, arg) => { - registers[arg - 1] = match reg { - RegType::Temp(x) => { - registers[x-1] - } - _ => unimplemented!() - }; + registers[arg - 1] = read_reg!(reg); } /* get_variable moves the content from Ai to Xi */ Instruction::GetVariable(reg, arg) => { - match reg { - RegType::Temp(x) => { - registers[x-1] = registers[arg - 1]; - } - _ => unimplemented!() - } + write_reg!(reg, registers[arg - 1]); } /* get_value perform unification between Xi and Ai */ Instruction::GetValue(reg1, reg2) => { - let reg1 = match reg1 { - RegType::Temp(x) => { - registers[x-1] - } - _ => unimplemented!() - }; + let reg1 = read_reg!(reg1); let reg2 = registers[reg2 - 1]; unify!(reg1, reg2); } + /* call executes another predicate in a normal way. It passes all the argument registers + * as function arguments + */ Instruction::CallNamed(arity, name, _ ) => { let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { return Err(JitCompileError::UndefinedPredicate); - }; +}; + let e_value = fn_builder.use_var(e); let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); let mut args = vec![]; args.push(heap); args.push(pdl); + args.push(stack); + args.push(e_value); for i in 1..=arity { args.push(registers[i-1]); } @@ -833,30 +831,61 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e fn_builder.ins().return_(&[fail_status]); fn_builder.switch_to_block(resume); } + /* execute does a tail call instead of a normal call + */ Instruction::ExecuteNamed(arity, name, _) => { let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { return Err(JitCompileError::UndefinedPredicate); }; + let e_value = fn_builder.use_var(e); let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); let mut args = vec![]; args.push(heap); args.push(pdl); + args.push(stack); + args.push(e_value); for i in 1..=arity { args.push(registers[i-1]); } fn_builder.ins().return_call(func, &args); } - // TODO: Manage RegType Perm + /* allocate creates a new environment frame (ANDFrame). Every frame contains a pointer + * to the previous frame start, the continuation pointer (in this case we do not store it + * as we have a tree of calls with Cranelift managing this for us), the number of permanent + * variables and the permanent variables themselves. + */ + Instruction::Allocate(n) => { + let new_e_value = vec_len!(stack); + let stack_ptr = vec_as_ptr!(stack); + let new_e_value = fn_builder.ins().imul_imm(new_e_value, 8); + let new_e_value = fn_builder.ins().iadd(stack_ptr, new_e_value); + let e_value = fn_builder.use_var(e); + vec_push!(stack, e_value); + let n_value = fn_builder.ins().iconst(types::I64, n as i64); + vec_push!(stack, n_value); + let zero = fn_builder.ins().iconst(types::I64, 0); + for _ in 0..n { + vec_push!(stack, zero); + } + fn_builder.def_var(e, new_e_value); + } + /* deallocate restores the previous frame, freeing the current frame + */ + Instruction::Deallocate => { + let e_value = fn_builder.use_var(e); + let allocated = fn_builder.ins().load(types::I64, MemFlags::trusted(), e_value, Offset32::new(8)); + let allocated = fn_builder.ins().iadd_imm(allocated, 2); + let new_e_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), e_value, Offset32::new(0)); + let stack_len = vec_len!(stack); + let new_stack_len = fn_builder.ins().isub(stack_len, allocated); + vec_truncate!(stack, new_stack_len); + fn_builder.def_var(e, new_e_value); + } // TODO: manage NonVar cases // TODO: Manage unification case // TODO: manage STORE[addr] is not REF Instruction::GetConstant(_, c, reg) => { - let value = match reg { - RegType::Temp(x) => { - registers[x-1] - } - _ => unimplemented!() - }; + let value = read_reg!(reg); let value = deref!(value); let value = store!(value); // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) @@ -899,11 +928,13 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e let Some(predicate) = self.predicates.get(&(name.to_string(), arity)) else { return Err(()); }; - let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; + let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec, *const Vec, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; let registers = machine_st.registers.as_ptr() as *mut Registers; let heap = &machine_st.heap as *const Vec; let pdl = &machine_st.pdl as *const Vec; - let fail = trampoline(predicate.code_ptr, registers, heap, pdl); + let stack_vec: Vec = Vec::with_capacity(1024); + let stack = &stack_vec as *const Vec; + let fail = trampoline(predicate.code_ptr, registers, heap, pdl, stack); machine_st.p = machine_st.cp; machine_st.fail = if fail == 1 { true @@ -1276,4 +1307,43 @@ fn test_execute_named() { assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); } + +#[test] +fn test_allocate() { + let mut machine_st = MachineState::new(); + let code_a = vec![ + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("a"), 0), RegType::Temp(1)), + Instruction::Proceed + ]; + let code_b = vec![ + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("b"), 0), RegType::Temp(1)), + Instruction::Proceed + ]; + let code_c = vec![ + Instruction::Allocate(1), + Instruction::GetVariable(RegType::Perm(1), 2), + Instruction::CallNamed(1, atom!("a"), CodeIndex::default(&mut machine_st.arena)), + Instruction::PutValue(RegType::Perm(1), 1), + Instruction::Deallocate, + Instruction::ExecuteNamed(1, atom!("b"), CodeIndex::default(&mut machine_st.arena)), + ]; + let x = heap_loc_as_cell!(0); + let y = heap_loc_as_cell!(1); + machine_st.registers[1] = x; + machine_st.registers[2] = y; + machine_st.heap.push(x); + machine_st.heap.push(y); + let mut jit = JitMachine::new(); + jit.compile("a", 1, code_a).unwrap(); + jit.compile("b", 1, code_b).unwrap(); + jit.compile("c", 2, code_c).unwrap(); + jit.exec("c", 2, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("a"), 0)); + machine_st_expected.heap.push(atom_as_cell!(atom!("b"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); + + +} // TODO: Continue with more tests From b3b3de307a7847458a2e0b7c967793177a2c133d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Tue, 20 Aug 2024 17:37:32 +0200 Subject: [PATCH 26/28] Exit early --- src/machine/jit2.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index ae97d5307..5d27756a7 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -218,10 +218,6 @@ impl JitMachine { fn_builder.declare_var(mode, types::I8); let s = Variable::new(1); fn_builder.declare_var(s, types::I64); - let fail = Variable::new(2); - fn_builder.declare_var(fail, types::I8); - let fail_value_init = fn_builder.ins().iconst(types::I8, 0); - fn_builder.def_var(fail, fail_value_init); let e = Variable::new(3); fn_builder.declare_var(e, types::I64); fn_builder.def_var(e, e_pointer); @@ -444,7 +440,7 @@ impl JitMachine { vec_push!(pdl, $x); vec_push!(pdl, $y); let pdl_size = fn_builder.ins().iconst(types::I64, 2); - let fail_status = fn_builder.use_var(fail); + let fail_status = fn_builder.ins().iconst(types::I8, 0); let pre_loop_block = fn_builder.create_block(); fn_builder.append_block_param(pre_loop_block, types::I64); // pdl_size @@ -559,7 +555,14 @@ impl JitMachine { // exit fn_builder.switch_to_block(exit); let fail_status = fn_builder.block_params(exit)[0]; - fn_builder.def_var(fail, fail_status); + let exit_early = fn_builder.create_block(); + let exit_normal = fn_builder.create_block(); + fn_builder.ins().brif(fail_status, exit_early, &[], exit_normal, &[]); + fn_builder.seal_block(exit_early); + fn_builder.seal_block(exit_normal); + fn_builder.switch_to_block(exit_early); + fn_builder.ins().return_(&[fail_status]); + fn_builder.switch_to_block(exit_normal); } } } @@ -700,8 +703,7 @@ impl JitMachine { // fail_block fn_builder.switch_to_block(fail_block); let fail_value = fn_builder.ins().iconst(types::I8, 1); - fn_builder.def_var(fail, fail_value); - fn_builder.ins().jump(exit_block, &[]); + fn_builder.ins().return_(&[fail_value]); // exit_block fn_builder.seal_block(exit_block); @@ -801,7 +803,6 @@ impl JitMachine { let reg1 = read_reg!(reg1); let reg2 = registers[reg2 - 1]; unify!(reg1, reg2); - } /* call executes another predicate in a normal way. It passes all the argument registers * as function arguments @@ -895,7 +896,8 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e Instruction::Proceed => { // do we really need to return registers? //fn_builder.ins().return_(®isters[0..arity]); - let fail_value = fn_builder.use_var(fail); + let fail_value = fn_builder.ins().iconst(types::I8, 0); + // if we exit here, it's because it succeeded fn_builder.ins().return_(&[fail_value]); break; }, From 2b906ae73ae2077e7375f485c4f2cc039f8e3023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Sat, 24 Aug 2024 12:37:35 +0200 Subject: [PATCH 27/28] Backtracking and heap reset --- src/machine/jit2.rs | 388 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 349 insertions(+), 39 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index 5d27756a7..d75ce6db3 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -6,7 +6,7 @@ use crate::instructions::*; use crate::machine::*; use crate::parser::ast::*; -use cranelift::prelude::*; +use cranelift::prelude::{*, Value}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Linkage, Module, FuncId}; use cranelift_codegen::{ir::stackslot::*, ir::entities::*, Context}; @@ -34,6 +34,10 @@ pub struct JitMachine { vec_len_sig: Signature, vec_truncate: *const u8, vec_truncate_sig: Signature, + vec_reserve: *const u8, + vec_reserve_sig: Signature, + vec_capacity: *const u8, + vec_capacity_sig: Signature, print_func: FuncId, print_func8: FuncId, vec_pop: FuncId, @@ -45,6 +49,12 @@ pub struct JitPredicate { func_id: FuncId, } +struct Backtrack { + block: Block, + trail_len_at_start: Value, + heap_len_at_start: Value, +} + impl std::fmt::Debug for JitMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "JitMachine") @@ -76,6 +86,7 @@ impl JitMachine { sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); sig.params.push(AbiParam::new(pointer_type)); + sig.params.push(AbiParam::new(pointer_type)); sig.returns.push(AbiParam::new(types::I8)); sig.call_conv = call_conv; @@ -93,18 +104,21 @@ impl JitMachine { let heap = fn_builder.block_params(block)[2]; let pdl = fn_builder.block_params(block)[3]; let stack = fn_builder.block_params(block)[4]; - let e_pointer = fn_builder.ins().iconst(types::I64, 0); + let stack_size = fn_builder.ins().iconst(types::I64, 0); + let trail = fn_builder.block_params(block)[5]; let mut jump_sig = module.make_signature(); jump_sig.call_conv = isa::CallConv::Tail; jump_sig.params.push(AbiParam::new(types::I64)); jump_sig.params.push(AbiParam::new(types::I64)); jump_sig.params.push(AbiParam::new(types::I64)); jump_sig.params.push(AbiParam::new(types::I64)); + jump_sig.params.push(AbiParam::new(types::I64)); let mut params = vec![]; params.push(heap); params.push(pdl); params.push(stack); - params.push(e_pointer); + params.push(stack_size); + params.push(trail); for i in 1..=n { jump_sig.params.push(AbiParam::new(types::I64)); // jump_sig.returns.push(AbiParam::new(types::I64)); @@ -166,6 +180,14 @@ impl JitMachine { let mut vec_truncate_sig = module.make_signature(); vec_truncate_sig.params.push(AbiParam::new(pointer_type)); vec_truncate_sig.params.push(AbiParam::new(types::I64)); + let vec_reserve = Vec::::reserve as *const u8; + let mut vec_reserve_sig = module.make_signature(); + vec_reserve_sig.params.push(AbiParam::new(pointer_type)); + vec_reserve_sig.params.push(AbiParam::new(types::I64)); + let vec_capacity = Vec::::capacity as *const u8; + let mut vec_capacity_sig = module.make_signature(); + vec_capacity_sig.params.push(AbiParam::new(pointer_type)); + vec_capacity_sig.returns.push(AbiParam::new(types::I64)); let predicates = HashMap::new(); JitMachine { @@ -181,6 +203,10 @@ impl JitMachine { vec_len_sig, vec_truncate, vec_truncate_sig, + vec_reserve, + vec_reserve_sig, + vec_capacity, + vec_capacity_sig, print_func, print_func8, vec_pop, @@ -195,6 +221,7 @@ impl JitMachine { sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); + sig.params.push(AbiParam::new(types::I64)); for _ in 1..=arity { sig.params.push(AbiParam::new(types::I64)); // sig.returns.push(AbiParam::new(types::I64)); @@ -213,20 +240,22 @@ impl JitMachine { let heap = fn_builder.block_params(block)[0]; let pdl = fn_builder.block_params(block)[1]; let stack = fn_builder.block_params(block)[2]; - let e_pointer = fn_builder.block_params(block)[3]; + let stack_size = fn_builder.block_params(block)[3]; + let trail = fn_builder.block_params(block)[4]; let mode = Variable::new(0); fn_builder.declare_var(mode, types::I8); let s = Variable::new(1); fn_builder.declare_var(s, types::I64); - let e = Variable::new(3); + let e = Variable::new(2); fn_builder.declare_var(e, types::I64); - fn_builder.def_var(e, e_pointer); + + let mut backtracks: Vec = vec![]; let mut registers = vec![]; // TODO: This could be optimized more, we know the maximum register we're using for i in 1..MAX_ARITY { if i <= arity { - let reg = fn_builder.block_params(block)[i + 3]; + let reg = fn_builder.block_params(block)[i + 4]; registers.push(reg); } else { let reg = fn_builder.ins().iconst(types::I64, 0); @@ -293,6 +322,27 @@ impl JitMachine { } } + macro_rules! vec_reserve { + ($x:expr, $y:expr) => { + { + let sig_ref = fn_builder.import_signature(self.vec_reserve_sig.clone()); + let vec_reserve_fn = fn_builder.ins().iconst(types::I64, self.vec_reserve as i64); + fn_builder.ins().call_indirect(sig_ref, vec_reserve_fn, &[$x, $y]); + } + } + } + + macro_rules! vec_capacity { + ($x:expr, $y:expr) => { + { + let sig_ref = fn_builder.import_signature(self.vec_capacity_sig.clone()); + let vec_capacity_fn = fn_builder.ins().iconst(types::I64, self.vec_capacity as i64); + let call = fn_builder.ins().call_indirect(sig_ref, vec_capacity_fn, &[$x]); + fn_builder.inst_results(call)[0] + } + } + } + macro_rules! vec_as_ptr { ($x:expr) => { { @@ -396,6 +446,7 @@ impl JitMachine { let idx = fn_builder.ins().ushr_imm(idx, 5); let idx = fn_builder.ins().iadd(heap_ptr, idx); fn_builder.ins().store(MemFlags::trusted(), $y, idx, Offset32::new(0)); + trail!($x); fn_builder.ins().jump(exit_block, &[]); // else_first_var_block // suppose the other cell is a var @@ -404,6 +455,7 @@ impl JitMachine { let idx = fn_builder.ins().ushr_imm(idx, 5); let idx = fn_builder.ins().iadd(heap_ptr, idx); fn_builder.ins().store(MemFlags::trusted(), $x, idx, Offset32::new(0)); + trail!($y); fn_builder.ins().jump(exit_block, &[]); // exit fn_builder.seal_block(exit_block); @@ -411,6 +463,68 @@ impl JitMachine { } } } + // TODO: Modify Bind to check address + + // TODO: Manage stack vars + // Stack vars are always cleaned. If there was an allocation after a choicepoint, they need to be removed. If there was + // an allocation before a choicepoint, the vars used need to be cleaned up. + // As we now allocate everything at the beginning, we can just clean everything at RetryMeElse and TrustMe + // However, we need to differentiate StackVar from Vars + macro_rules! trail { + ($x:expr) => { + if !backtracks.is_empty() { + let push_var = fn_builder.create_block(); + let exit = fn_builder.create_block(); + let current_frame = backtracks.get(backtracks.len() - 1).unwrap(); + let idx = fn_builder.ins().ishl_imm($x, 8); + let idx = fn_builder.ins().ushr_imm(idx, 8); + let var_is_older = fn_builder.ins().icmp(IntCC::SignedLessThan, idx, current_frame.heap_len_at_start); + fn_builder.ins().brif(var_is_older, push_var, &[], exit, &[]); + fn_builder.seal_block(push_var); + fn_builder.switch_to_block(push_var); + vec_push!(trail, $x); + fn_builder.ins().jump(exit, &[]); + fn_builder.seal_block(exit); + fn_builder.switch_to_block(exit); + } + } + } + + macro_rules! unwind_trail { + () => { + { + if !backtracks.is_empty() { + let heap_ptr = vec_as_ptr!(heap); + let current_frame = backtracks.get(backtracks.len() - 1).unwrap(); + let trail_len_at_start = current_frame.trail_len_at_start; + let trail_len_now = vec_len!(trail); + let num_items = fn_builder.ins().isub(trail_len_now, trail_len_at_start); + let check_loop = fn_builder.create_block(); + fn_builder.append_block_param(check_loop, types::I64); + let exit = fn_builder.create_block(); + let loop_body = fn_builder.create_block(); + fn_builder.ins().jump(check_loop, &[num_items]); + fn_builder.switch_to_block(check_loop); + let num_items = fn_builder.block_params(check_loop)[0]; + let is_zero = fn_builder.ins().icmp_imm(IntCC::Equal, num_items, 0); + fn_builder.ins().brif(is_zero, exit, &[], loop_body, &[]); + fn_builder.seal_block(exit); + fn_builder.seal_block(loop_body); + fn_builder.switch_to_block(loop_body); + // unwind here + let cell = vec_pop!(trail); + let idx = fn_builder.ins().ishl_imm(cell, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + fn_builder.ins().store(MemFlags::trusted(), cell, idx, Offset32::new(0)); + let num_items = fn_builder.ins().iadd_imm(num_items, -1); + fn_builder.ins().jump(check_loop, &[num_items]); + fn_builder.seal_block(check_loop); + fn_builder.switch_to_block(exit); + } + } + } + } macro_rules! is_var { ($x:expr) => { @@ -421,6 +535,15 @@ impl JitMachine { } } + macro_rules! is_stack_var { + ($x:expr) => { + { + let tag = fn_builder.ins().ushr_imm($x, 58); + fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::StackVar as i64) + } + } + } + macro_rules! is_str { ($x:expr) => { { @@ -575,7 +698,7 @@ impl JitMachine { registers[x-1] } RegType::Perm(y) => { - let idy = ((y as i32) + 1) * 8; + let idy = ((y as i32) - 1) * 8; let stack_frame = fn_builder.use_var(e); fn_builder.ins().load(types::I64, MemFlags::trusted(), stack_frame, Offset32::new(idy)) } @@ -591,7 +714,7 @@ impl JitMachine { registers[x-1] = $y; } RegType::Perm(y) => { - let idy = ((y as i32) + 1) * 8; + let idy = ((y as i32) - 1) * 8; let stack_frame = fn_builder.use_var(e); fn_builder.ins().store(MemFlags::trusted(), $y, stack_frame, Offset32::new(idy)); } @@ -599,6 +722,22 @@ impl JitMachine { } } + // reserve all allocations at once + let mut allocation_size: i64 = 0; + for wam_instr in &code { + if let Instruction::Allocate(n) = wam_instr { + allocation_size = i64::max(allocation_size, *n as i64); + } + } + + let allocation_size_value = fn_builder.ins().iconst(types::I64, allocation_size); + let new_stack_size = fn_builder.ins().iadd(stack_size, allocation_size_value); + vec_reserve!(stack, new_stack_size); + let stack_ptr = vec_as_ptr!(stack); + let stack_size_bytes = fn_builder.ins().imul_imm(stack_size, 8); + let env_ptr = fn_builder.ins().iadd(stack_ptr, stack_size_bytes); + fn_builder.def_var(e, env_ptr); + for wam_instr in code { match wam_instr { /* put_structure is an instruction that puts a new STR in the heap @@ -811,13 +950,13 @@ impl JitMachine { let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { return Err(JitCompileError::UndefinedPredicate); }; - let e_value = fn_builder.use_var(e); let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); let mut args = vec![]; args.push(heap); args.push(pdl); args.push(stack); - args.push(e_value); + args.push(new_stack_size); + args.push(trail); for i in 1..=arity { args.push(registers[i-1]); } @@ -829,7 +968,13 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e fn_builder.seal_block(exit_early); fn_builder.seal_block(resume); fn_builder.switch_to_block(exit_early); - fn_builder.ins().return_(&[fail_status]); + if backtracks.is_empty() { + fn_builder.ins().return_(&[fail_status]); + } else { + //let last_backtrack = backtracks[backtracks.len() -1 ]; + // TODO: Clean TRAIL + //fn_builder.ins().jump(last_backtrack.next_block, &[]); + } fn_builder.switch_to_block(resume); } /* execute does a tail call instead of a normal call @@ -838,49 +983,45 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) else { return Err(JitCompileError::UndefinedPredicate); }; - let e_value = fn_builder.use_var(e); let func = self.module.declare_func_in_func(predicate.func_id, fn_builder.func); let mut args = vec![]; args.push(heap); args.push(pdl); args.push(stack); - args.push(e_value); + args.push(stack_size); + args.push(trail); for i in 1..=arity { args.push(registers[i-1]); } - fn_builder.ins().return_call(func, &args); + if backtracks.is_empty() { + // we can only optimize tail calls if there's no backtracking in this level + fn_builder.ins().return_call(func, &args); + } else { + // TODO CLEAN TRAIL + let backtrack = backtracks.get(backtracks.len() -1 ).unwrap(); + let call = fn_builder.ins().call(func, &args); + let fail_status = fn_builder.inst_results(call)[0]; + let exit_normal = fn_builder.create_block(); + fn_builder.ins().brif(fail_status, backtrack.block, &[], exit_normal, &[]); + fn_builder.seal_block(exit_normal); + fn_builder.switch_to_block(exit_normal); + let fail_value = fn_builder.ins().iconst(types::I8, 0); + // if we exit here, it's because it succeeded + fn_builder.ins().return_(&[fail_value]); + } } /* allocate creates a new environment frame (ANDFrame). Every frame contains a pointer * to the previous frame start, the continuation pointer (in this case we do not store it * as we have a tree of calls with Cranelift managing this for us), the number of permanent * variables and the permanent variables themselves. */ - Instruction::Allocate(n) => { - let new_e_value = vec_len!(stack); - let stack_ptr = vec_as_ptr!(stack); - let new_e_value = fn_builder.ins().imul_imm(new_e_value, 8); - let new_e_value = fn_builder.ins().iadd(stack_ptr, new_e_value); - let e_value = fn_builder.use_var(e); - vec_push!(stack, e_value); - let n_value = fn_builder.ins().iconst(types::I64, n as i64); - vec_push!(stack, n_value); - let zero = fn_builder.ins().iconst(types::I64, 0); - for _ in 0..n { - vec_push!(stack, zero); - } - fn_builder.def_var(e, new_e_value); + Instruction::Allocate(_n) => { + } /* deallocate restores the previous frame, freeing the current frame */ Instruction::Deallocate => { - let e_value = fn_builder.use_var(e); - let allocated = fn_builder.ins().load(types::I64, MemFlags::trusted(), e_value, Offset32::new(8)); - let allocated = fn_builder.ins().iadd_imm(allocated, 2); - let new_e_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), e_value, Offset32::new(0)); - let stack_len = vec_len!(stack); - let new_stack_len = fn_builder.ins().isub(stack_len, allocated); - vec_truncate!(stack, new_stack_len); - fn_builder.def_var(e, new_e_value); + } // TODO: manage NonVar cases // TODO: Manage unification case @@ -901,6 +1042,38 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e fn_builder.ins().return_(&[fail_value]); break; }, + Instruction::TryMeElse(_offset) => { + let block = fn_builder.create_block(); + let heap_len_at_start = vec_len!(heap); + let trail_len_at_start = vec_len!(trail); + backtracks.push(Backtrack { + block, + heap_len_at_start, + trail_len_at_start, + }); + } + Instruction::RetryMeElse(_offset) => { + let continuation = backtracks.get(backtracks.len() - 1).unwrap(); + fn_builder.seal_block(continuation.block); + fn_builder.switch_to_block(continuation.block); + let heap_len_at_start = continuation.heap_len_at_start; + let trail_len_at_start = continuation.trail_len_at_start; + unwind_trail!(); + backtracks.truncate(backtracks.len() - 1); + let block = fn_builder.create_block(); + backtracks.push(Backtrack { + block, + heap_len_at_start, + trail_len_at_start, + }); + } + Instruction::TrustMe(_) => { + let continuation = backtracks.get(backtracks.len() - 1).unwrap(); + fn_builder.seal_block(continuation.block); + fn_builder.switch_to_block(continuation.block); + unwind_trail!(); + backtracks.truncate(backtracks.len() - 1); + } _ => { dbg!(wam_instr); fn_builder.finalize(); @@ -930,13 +1103,15 @@ let Some(predicate) = self.predicates.get(&(name.as_str().to_string(), arity)) e let Some(predicate) = self.predicates.get(&(name.to_string(), arity)) else { return Err(()); }; - let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec, *const Vec, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; + let trampoline: extern "C" fn (*const u8, *mut Registers, *const Vec, *const Vec, *const Vec, *const Vec) -> u8 = unsafe { std::mem::transmute(self.trampolines[arity])}; let registers = machine_st.registers.as_ptr() as *mut Registers; let heap = &machine_st.heap as *const Vec; let pdl = &machine_st.pdl as *const Vec; let stack_vec: Vec = Vec::with_capacity(1024); let stack = &stack_vec as *const Vec; - let fail = trampoline(predicate.code_ptr, registers, heap, pdl, stack); + let trail_vec: Vec = Vec::with_capacity(1024); + let trail = &trail_vec as *const Vec; + let fail = trampoline(predicate.code_ptr, registers, heap, pdl, stack, trail); machine_st.p = machine_st.cp; machine_st.fail = if fail == 1 { true @@ -1348,4 +1523,139 @@ fn test_allocate() { } + +#[test] +fn test_backtracking_1() { + let mut machine_st = MachineState::new(); + let code_fail = vec![ + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("h"), 1, RegType::Temp(2)), + Instruction::Proceed + ]; + let code_ok = vec![ + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("a"), 0), RegType::Temp(1)), + Instruction::Proceed + ]; + let code = vec![ + Instruction::Allocate(1), + Instruction::GetVariable(RegType::Perm(1), 1), + Instruction::TryMeElse(4), + Instruction::PutValue(RegType::Perm(1), 1), + Instruction::Deallocate, + Instruction::ExecuteNamed(1, atom!("a"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(4), + Instruction::PutValue(RegType::Perm(1), 1), + Instruction::Deallocate, + Instruction::ExecuteNamed(1, atom!("b"), CodeIndex::default(&mut machine_st.arena)), + Instruction::TrustMe(0), + Instruction::PutValue(RegType::Perm(1), 1), + Instruction::Deallocate, + Instruction::ExecuteNamed(1, atom!("c"), CodeIndex::default(&mut machine_st.arena)) + ]; + let x = heap_loc_as_cell!(0); + machine_st.registers[1] = x; + machine_st.heap.push(x); + let mut jit = JitMachine::new(); + jit.compile("a", 1, code_fail.clone()).unwrap(); + jit.compile("b", 1, code_fail).unwrap(); + jit.compile("c", 1, code_ok).unwrap(); + jit.compile("d", 1, code).unwrap(); + jit.exec("d", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("a"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(2)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(4)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + + +#[test] +fn test_backtracking_2() { + let mut machine_st = MachineState::new(); + let code_fail = vec![ + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("h"), 1, RegType::Temp(2)), + Instruction::Proceed + ]; + let code = vec![ + Instruction::TryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("a"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("b"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("c"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::TrustMe(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("d"), 0), RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("f", 0, code_fail).unwrap(); + jit.compile("a", 1, code).unwrap(); + let x = heap_loc_as_cell!(0); + machine_st.registers[1] = x; + machine_st.heap.push(x); + jit.exec("a", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(atom_as_cell!(atom!("d"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(2)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(4)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(6)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} + +#[test] +fn test_backtracking_3() { + let mut machine_st = MachineState::new(); + let code_fail = vec![ + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("h"), 1, RegType::Temp(2)), + Instruction::Proceed + ]; + let code = vec![ + Instruction::TryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("a"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("b"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("c"), 0), RegType::Temp(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::TrustMe(0), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("f", 0, code_fail).unwrap(); + jit.compile("a", 1, code).unwrap(); + let x = heap_loc_as_cell!(0); + machine_st.registers[1] = x; + machine_st.heap.push(x); + jit.exec("a", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + machine_st_expected.heap.push(heap_loc_as_cell!(0)); + machine_st_expected.heap.push(str_loc_as_cell!(2)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(4)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + machine_st_expected.heap.push(str_loc_as_cell!(6)); + machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 0)); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +} +// TODO: Backtracking 4 test stack trail + // TODO: Continue with more tests +// One option is to only deallocate at the end of functions, taking into account how many allocations took place +// Other option is to store deallocation info in backtracks + +// TODO: Move heap, stack, pdl and trail to GlobalValue From 904fba4a0f4e9f84f67fbf2d89485fd37fd90722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Arroyo=20Calle?= Date: Fri, 30 Aug 2024 17:58:42 +0200 Subject: [PATCH 28/28] Stack backtracking I guess --- src/machine/jit2.rs | 149 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 120 insertions(+), 29 deletions(-) diff --git a/src/machine/jit2.rs b/src/machine/jit2.rs index d75ce6db3..ece962726 100644 --- a/src/machine/jit2.rs +++ b/src/machine/jit2.rs @@ -432,38 +432,60 @@ impl JitMachine { let first_var_block = fn_builder.create_block(); let else_first_var_block = fn_builder.create_block(); let exit_block = fn_builder.create_block(); - let heap_ptr = vec_as_ptr!(heap); - // check if x is var - let tag = fn_builder.ins().ushr_imm($x, 58); - let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64); - fn_builder.ins().brif(is_var, first_var_block, &[], else_first_var_block, &[]); + let second_check_block = fn_builder.create_block(); + let third_check_block = fn_builder.create_block(); + let push_var_y = fn_builder.create_block(); + let push_stackvar_y = fn_builder.create_block(); + let heap_ptr = vec_as_ptr!(heap); + let stack_ptr = vec_as_ptr!(stack); + // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) + let idx = fn_builder.ins().ishl_imm($x, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(heap_ptr, idx); + let idy = fn_builder.ins().ishl_imm($y, 8); + let idy = fn_builder.ins().ushr_imm(idy, 5); + let heap_idy = fn_builder.ins().iadd(heap_ptr, idy); + let stack_idy = fn_builder.ins().iadd(stack_ptr, idy); + // first case: X is a var, Y is not a stack var, (if Y is a var too, it should be lower) + // second case: else + let x_is_var = is_var!($x); + let y_is_stack_var = is_stack_var!($y); + let y_is_var = is_var!($y); + let y_is_lower_than_x = fn_builder.ins().icmp(IntCC::SignedLessThan, heap_idy, idx); + let check = fn_builder.ins().band_not(x_is_var, y_is_stack_var); + fn_builder.ins().brif(check, second_check_block, &[], else_first_var_block, &[]); + fn_builder.seal_block(second_check_block); + fn_builder.switch_to_block(second_check_block); + fn_builder.ins().brif(y_is_var, third_check_block, &[], first_var_block, &[]); + fn_builder.seal_block(third_check_block); + fn_builder.switch_to_block(third_check_block); + fn_builder.ins().brif(y_is_lower_than_x, first_var_block, &[], else_first_var_block, &[]); // first var block fn_builder.seal_block(first_var_block); fn_builder.seal_block(else_first_var_block); fn_builder.switch_to_block(first_var_block); - // The order of HeapCellValue is TAG (6), M (1), F (1), VALUE (56) - let idx = fn_builder.ins().ishl_imm($x, 8); - let idx = fn_builder.ins().ushr_imm(idx, 5); - let idx = fn_builder.ins().iadd(heap_ptr, idx); fn_builder.ins().store(MemFlags::trusted(), $y, idx, Offset32::new(0)); trail!($x); fn_builder.ins().jump(exit_block, &[]); // else_first_var_block - // suppose the other cell is a var fn_builder.switch_to_block(else_first_var_block); - let idx = fn_builder.ins().ishl_imm($y, 8); - let idx = fn_builder.ins().ushr_imm(idx, 5); - let idx = fn_builder.ins().iadd(heap_ptr, idx); - fn_builder.ins().store(MemFlags::trusted(), $x, idx, Offset32::new(0)); + fn_builder.ins().brif(y_is_var, push_var_y, &[], push_stackvar_y, &[]); + fn_builder.seal_block(push_var_y); + fn_builder.seal_block(push_stackvar_y); + fn_builder.switch_to_block(push_var_y); + fn_builder.ins().store(MemFlags::trusted(), $x, heap_idy, Offset32::new(0)); + trail!($y); + fn_builder.ins().jump(exit_block, &[]); + fn_builder.switch_to_block(push_stackvar_y); + fn_builder.ins().store(MemFlags::trusted(), $x, stack_idy, Offset32::new(0)); trail!($y); - fn_builder.ins().jump(exit_block, &[]); + fn_builder.ins().jump(exit_block, &[]); // exit fn_builder.seal_block(exit_block); fn_builder.switch_to_block(exit_block); } } } - // TODO: Modify Bind to check address // TODO: Manage stack vars // Stack vars are always cleaned. If there was an allocation after a choicepoint, they need to be removed. If there was @@ -474,16 +496,23 @@ impl JitMachine { ($x:expr) => { if !backtracks.is_empty() { let push_var = fn_builder.create_block(); + let check_push_stack_var = fn_builder.create_block(); let exit = fn_builder.create_block(); let current_frame = backtracks.get(backtracks.len() - 1).unwrap(); let idx = fn_builder.ins().ishl_imm($x, 8); let idx = fn_builder.ins().ushr_imm(idx, 8); let var_is_older = fn_builder.ins().icmp(IntCC::SignedLessThan, idx, current_frame.heap_len_at_start); - fn_builder.ins().brif(var_is_older, push_var, &[], exit, &[]); - fn_builder.seal_block(push_var); + let x_is_var = is_var!($x); + let var_is_older_and_var = fn_builder.ins().band(var_is_older, x_is_var); + fn_builder.ins().brif(var_is_older_and_var, push_var, &[], check_push_stack_var, &[]); + fn_builder.seal_block(check_push_stack_var); fn_builder.switch_to_block(push_var); vec_push!(trail, $x); fn_builder.ins().jump(exit, &[]); + fn_builder.switch_to_block(check_push_stack_var); + let var_is_older = fn_builder.ins().icmp(IntCC::SignedLessThan, idx, stack_size); + fn_builder.ins().brif(var_is_older, push_var, &[], exit, &[]); + fn_builder.seal_block(push_var); fn_builder.seal_block(exit); fn_builder.switch_to_block(exit); } @@ -495,6 +524,7 @@ impl JitMachine { { if !backtracks.is_empty() { let heap_ptr = vec_as_ptr!(heap); + let stack_ptr = vec_as_ptr!(stack); let current_frame = backtracks.get(backtracks.len() - 1).unwrap(); let trail_len_at_start = current_frame.trail_len_at_start; let trail_len_now = vec_len!(trail); @@ -513,12 +543,30 @@ impl JitMachine { fn_builder.switch_to_block(loop_body); // unwind here let cell = vec_pop!(trail); + let is_var = is_var!(cell); + let restore_var = fn_builder.create_block(); + fn_builder.append_block_param(restore_var, types::I64); + let restore_stack_var = fn_builder.create_block(); + fn_builder.append_block_param(restore_stack_var, types::I64); + fn_builder.ins().brif(is_var, restore_var, &[num_items], restore_stack_var, &[num_items]); + fn_builder.seal_block(restore_var); + fn_builder.seal_block(restore_stack_var); + fn_builder.switch_to_block(restore_var); + let num_items = fn_builder.block_params(restore_var)[0]; let idx = fn_builder.ins().ishl_imm(cell, 8); let idx = fn_builder.ins().ushr_imm(idx, 5); let idx = fn_builder.ins().iadd(heap_ptr, idx); fn_builder.ins().store(MemFlags::trusted(), cell, idx, Offset32::new(0)); let num_items = fn_builder.ins().iadd_imm(num_items, -1); fn_builder.ins().jump(check_loop, &[num_items]); + fn_builder.switch_to_block(restore_stack_var); + let num_items = fn_builder.block_params(restore_stack_var)[0]; + let idx = fn_builder.ins().ishl_imm(cell, 8); + let idx = fn_builder.ins().ushr_imm(idx, 5); + let idx = fn_builder.ins().iadd(stack_ptr, idx); + fn_builder.ins().store(MemFlags::trusted(), cell, idx, Offset32::new(0)); + let num_items = fn_builder.ins().iadd_imm(num_items, -1); + fn_builder.ins().jump(check_loop, &[num_items]); fn_builder.seal_block(check_loop); fn_builder.switch_to_block(exit); } @@ -918,13 +966,23 @@ impl JitMachine { * Xi normal register and Ai argument register */ Instruction::PutVariable(reg, arg) => { - let heap_loc_cell = heap_loc_as_cell!(0); - let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); - let vec_len = vec_len!(heap); - let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); - vec_push!(heap, heap_loc_cell); - write_reg!(reg, heap_loc_cell); - registers[arg - 1] = heap_loc_cell; + match reg { + RegType::Temp(_) => { + let heap_loc_cell = heap_loc_as_cell!(0); + let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes())); + let vec_len = vec_len!(heap); + let heap_loc_cell = fn_builder.ins().bor(vec_len, heap_loc_cell); + vec_push!(heap, heap_loc_cell); + write_reg!(reg, heap_loc_cell); + registers[arg - 1] = heap_loc_cell; + } + RegType::Perm(y) => { + let stack_loc_cell = stack_loc_as_cell!(y - 1); + let stack_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(stack_loc_cell.into_bytes())); + write_reg!(reg, stack_loc_cell); + registers[arg - 1] = stack_loc_cell; + } + } } /* put_value moves the content from Xi to Ai */ @@ -1389,8 +1447,8 @@ fn test_unify_value_read_str_3() { let mut machine_st_expected = MachineState::new(); machine_st_expected.heap.push(str_loc_as_cell!(1)); machine_st_expected.heap.push(atom_as_cell!(atom!("f"), 1)); - machine_st_expected.heap.push(heap_loc_as_cell!(3)); - machine_st_expected.heap.push(heap_loc_as_cell!(3)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); + machine_st_expected.heap.push(heap_loc_as_cell!(2)); assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); @@ -1652,10 +1710,43 @@ fn test_backtracking_3() { assert_eq!(machine_st.heap, machine_st_expected.heap); assert_eq!(machine_st.fail, false); } + +/*#[test] +fn test_backtracking_4() { + let mut machine_st = MachineState::new(); + let code_fail = vec![ + Instruction::PutStructure(atom!("f"), 0, RegType::Temp(2)), + Instruction::GetStructure(Level::Shallow, atom!("h"), 1, RegType::Temp(2)), + Instruction::Proceed + ]; + let code = vec![ + Instruction::Allocate(1), + Instruction::PutVariable(RegType::Perm(1), 1), + Instruction::TryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("a"), 0), RegType::Perm(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::RetryMeElse(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("b"), 0), RegType::Perm(1)), + Instruction::ExecuteNamed(0, atom!("f"), CodeIndex::default(&mut machine_st.arena)), + Instruction::TrustMe(0), + Instruction::GetConstant(Level::Shallow, atom_as_cell!(atom!("c"), 0), RegType::Perm(1)), + Instruction::GetVariable(RegType::Temp(1), 1), + Instruction::SetValue(RegType::Temp(1)), + Instruction::Proceed + ]; + let mut jit = JitMachine::new(); + jit.compile("f", 0, code_fail).unwrap(); + jit.compile("a", 1, code).unwrap(); + let x = heap_loc_as_cell!(0); + machine_st.registers[1] = x; + machine_st.heap.push(x); + jit.exec("a", 1, &mut machine_st).unwrap(); + let mut machine_st_expected = MachineState::new(); + assert_eq!(machine_st.heap, machine_st_expected.heap); + assert_eq!(machine_st.fail, false); +}*/ // TODO: Backtracking 4 test stack trail // TODO: Continue with more tests -// One option is to only deallocate at the end of functions, taking into account how many allocations took place -// Other option is to store deallocation info in backtracks // TODO: Move heap, stack, pdl and trail to GlobalValue