Skip to content

Commit

Permalink
Add more instructions to JIT2
Browse files Browse the repository at this point in the history
  • Loading branch information
aarroyoc committed Jul 2, 2024
1 parent e8cbd2e commit 3dc92f2
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 12 deletions.
243 changes: 235 additions & 8 deletions src/machine/jit2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
use cranelift_codegen::Context;
use cranelift::prelude::codegen::ir::immediates::Offset32;
use cranelift::prelude::codegen::ir::entities::Value;

use std::ops::Index;

#[derive(Debug, PartialEq)]
pub enum JitCompileError {
Expand All @@ -24,6 +27,12 @@ pub struct JitMachine {
module: JITModule,
ctx: Context,
func_ctx: FunctionBuilderContext,
heap_as_ptr: *const u8,
heap_as_ptr_sig: Signature,
heap_push: *const u8,
heap_push_sig: Signature,
heap_len: *const u8,
heap_len_sig: Signature,
}

impl std::fmt::Debug for JitMachine {
Expand Down Expand Up @@ -85,27 +94,247 @@ impl JitMachine {
let code_ptr: *const u8 = unsafe { std::mem::transmute(module.get_finalized_function(func)) };
trampolines.push(code_ptr);
}


let heap_as_ptr = Vec::<HeapCellValue>::as_ptr as *const u8;
let mut heap_as_ptr_sig = module.make_signature();
heap_as_ptr_sig.params.push(AbiParam::new(pointer_type));
heap_as_ptr_sig.returns.push(AbiParam::new(pointer_type));
let heap_push = Vec::<HeapCellValue>::push as *const u8;
let mut heap_push_sig = module.make_signature();
heap_push_sig.params.push(AbiParam::new(pointer_type));
heap_push_sig.params.push(AbiParam::new(types::I64));
let heap_len = Vec::<HeapCellValue>::len as *const u8;
let mut heap_len_sig = module.make_signature();
heap_len_sig.params.push(AbiParam::new(pointer_type));
heap_len_sig.returns.push(AbiParam::new(types::I64));
JitMachine {
trampolines,
module,
ctx,
func_ctx,
heap_as_ptr,
heap_as_ptr_sig,
heap_push,
heap_push_sig,
heap_len,
heap_len_sig,
}
}

// TODO: Compile taking into account arity
pub fn compile(&mut self, name: &str, code: Code) -> Result<(), JitCompileError> {
pub fn compile(&mut self, name: &str, arity: usize, code: Code) -> Result<(), JitCompileError> {
let mut sig = self.module.make_signature();
sig.params.push(AbiParam::new(types::I64));
for _ in 1..=arity {
sig.params.push(AbiParam::new(types::I64));
sig.returns.push(AbiParam::new(types::I64));
}
sig.call_conv = isa::CallConv::Tail;
self.ctx.func.signature = sig.clone();

let mut fn_builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.func_ctx);
let block = fn_builder.create_block();
fn_builder.append_block_params_for_function_params(block);
fn_builder.switch_to_block(block);
fn_builder.seal_block(block);

let heap = fn_builder.block_params(block)[0];
let mode = Variable::new(0);
fn_builder.declare_var(mode, types::I8);
let s = Variable::new(1);
fn_builder.declare_var(s, types::I64);
let fail = Variable::new(2);
fn_builder.declare_var(fail, types::I8);

let mut registers = vec![];
for i in 1..=arity {
let reg = fn_builder.block_params(block)[i];
registers.push(reg);
}

macro_rules! heap_len {
() => {
{let sig_ref = fn_builder.import_signature(self.heap_len_sig.clone());
let heap_len_fn = fn_builder.ins().iconst(types::I64, self.heap_len as i64);
let call_heap_len = fn_builder.ins().call_indirect(sig_ref, heap_len_fn, &[heap]);
let heap_len = fn_builder.inst_results(call_heap_len)[0];
heap_len}
}
}

macro_rules! heap_as_ptr {
() => {
{
let sig_ref = fn_builder.import_signature(self.heap_as_ptr_sig.clone());
let heap_as_ptr_fn = fn_builder.ins().iconst(types::I64, self.heap_as_ptr as i64);
let call_heap_as_ptr = fn_builder.ins().call_indirect(sig_ref, heap_as_ptr_fn, &[heap]);
let heap_ptr = fn_builder.inst_results(call_heap_as_ptr)[0];
heap_ptr
}
}
}

macro_rules! store {
($x:expr) => {
{
let merge_block = fn_builder.create_block();
fn_builder.append_block_param(merge_block, types::I64);
let is_var_block = fn_builder.create_block();
fn_builder.append_block_param(is_var_block, types::I64);
let is_not_var_block = fn_builder.create_block();
fn_builder.append_block_param(is_not_var_block, types::I64);
let tag = fn_builder.ins().band_imm($x, 64);
let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64);
fn_builder.ins().brif(is_var, is_var_block, &[$x], is_not_var_block, &[$x]);
// is_var
fn_builder.switch_to_block(is_var_block);
fn_builder.seal_block(is_var_block);
let param = fn_builder.block_params(is_var_block)[0];
let idx = fn_builder.ins().ushr_imm(param, 8);
let heap_ptr = heap_as_ptr!();
let idx_ptr = fn_builder.ins().imul_imm(idx, 8);
let idx_ptr = fn_builder.ins().iadd(heap_ptr, idx_ptr);
let heap_value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx_ptr, Offset32::new(0));
fn_builder.ins().jump(merge_block, &[heap_value]);
// is_not_var
fn_builder.switch_to_block(is_not_var_block);
fn_builder.seal_block(is_not_var_block);
let param = fn_builder.block_params(is_not_var_block)[0];
fn_builder.ins().jump(merge_block, &[param]);
// merge
fn_builder.switch_to_block(merge_block);
fn_builder.seal_block(merge_block);
fn_builder.block_params(merge_block)[0]
}
}
}

macro_rules! deref {
($x:expr) => {
{
let exit_block = fn_builder.create_block();
fn_builder.append_block_param(exit_block, types::I64);
let loop_block = fn_builder.create_block();
fn_builder.append_block_param(loop_block, types::I64);
fn_builder.ins().jump(loop_block, &[$x]);
fn_builder.switch_to_block(loop_block);
let addr = fn_builder.block_params(loop_block)[0];
let value = store!(addr);
// check if is var
let tag = fn_builder.ins().band_imm(value, 64);
let is_var = fn_builder.ins().icmp_imm(IntCC::Equal, tag, HeapCellValueTag::Var as i64);
let not_equal = fn_builder.ins().icmp(IntCC::NotEqual, value, addr);
let check = fn_builder.ins().band(is_var, not_equal);
fn_builder.ins().brif(check, loop_block, &[value], exit_block, &[value]);
// exit
fn_builder.seal_block(loop_block);
fn_builder.seal_block(exit_block);
fn_builder.switch_to_block(exit_block);
fn_builder.block_params(exit_block)[0]

}
}
}

for wam_instr in code {
match wam_instr {
// TODO Missing RegType Perm
Instruction::PutStructure(name, arity, reg) => {
let atom_cell = atom_as_cell!(name, arity);
let atom = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(atom_cell.into_bytes()));
let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone());
let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64);
fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, atom]);
let str_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(str_loc_as_cell!(0).into_bytes()));
let heap_len = heap_len!();
let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8);
let str_cell = fn_builder.ins().bor(heap_len_shift, str_cell);
match reg {
RegType::Temp(x) => {
registers[x] = str_cell;
}
_ => unimplemented!()
}
}
// TODO Missing RegType Perm
Instruction::SetVariable(reg) => {
let heap_loc_cell = heap_loc_as_cell!(0);
let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes()));
let heap_len = heap_len!();
let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8);
let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell);
let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone());
let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64);
fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]);
match reg {
RegType::Temp(x) => {
registers[x] = heap_loc_cell;
}
_ => unimplemented!()
}
}
// TODO: Missing RegType Perm
Instruction::SetValue(reg) => {
let value = match reg {
RegType::Temp(x) => {
registers[x]
},
_ => unimplemented!()
};
let value = store!(value);

let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone());
let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64);
fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, value]);
}
// TODO: Missing RegType Perm. Let's suppose Mode is local to each predicate
// TODO: Missing support for PStr and CStr
Instruction::UnifyVariable(reg) => {
let read_block = fn_builder.create_block();
let write_block = fn_builder.create_block();
let exit_block = fn_builder.create_block();
let mode_value = fn_builder.use_var(mode);
fn_builder.ins().brif(mode_value, write_block, &[], read_block, &[]);
fn_builder.seal_block(read_block);
fn_builder.seal_block(write_block);
// read
fn_builder.switch_to_block(read_block);
let heap_ptr = heap_as_ptr!();
let s_value = fn_builder.use_var(s);
let idx = fn_builder.ins().iadd(heap_ptr, s_value);
let value = fn_builder.ins().load(types::I64, MemFlags::trusted(), idx, Offset32::new(0));
let value = deref!(value);
match reg {
RegType::Temp(x) => {
registers[x] = value;
},
_ => unimplemented!()
}
let sum_s = fn_builder.ins().iadd_imm(s_value, 8);
fn_builder.def_var(s, sum_s);
fn_builder.ins().jump(exit_block, &[]);
// write (equal to SetVariable)
fn_builder.switch_to_block(write_block);
let heap_loc_cell = heap_loc_as_cell!(0);
let heap_loc_cell = fn_builder.ins().iconst(types::I64, i64::from_le_bytes(heap_loc_cell.into_bytes()));
let heap_len = heap_len!();
let heap_len_shift = fn_builder.ins().ishl_imm(heap_len, 8);
let heap_loc_cell = fn_builder.ins().bor(heap_len_shift, heap_loc_cell);
let sig_ref = fn_builder.import_signature(self.heap_push_sig.clone());
let heap_push_fn = fn_builder.ins().iconst(types::I64, self.heap_push as i64);
fn_builder.ins().call_indirect(sig_ref, heap_push_fn, &[heap, heap_loc_cell]);
match reg {
RegType::Temp(x) => {
registers[x] = heap_loc_cell;
}
_ => unimplemented!()
}
fn_builder.ins().jump(exit_block, &[]);
// exit
fn_builder.switch_to_block(exit_block);
fn_builder.seal_block(exit_block);

}
Instruction::Proceed => {
fn_builder.ins().return_(&[]);
fn_builder.ins().return_(&registers);
break;
},
_ => {
Expand All @@ -118,8 +347,6 @@ impl JitMachine {
fn_builder.seal_all_blocks();
fn_builder.finalize();

let mut sig = self.module.make_signature();
sig.call_conv = isa::CallConv::Tail;

let func = self.module.declare_function(name, Linkage::Local, &sig).unwrap();
self.module.define_function(func, &mut self.ctx).unwrap();
Expand All @@ -129,6 +356,6 @@ impl JitMachine {
}

pub fn exec(&self, name: &str, machine_st: &mut MachineState) -> Result<(), ()> {
Ok(())
Err(())
}
}
4 changes: 2 additions & 2 deletions src/machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub mod dispatch;
pub mod gc;
pub mod heap;
#[cfg(feature = "jit")]
pub mod jit;
pub mod jit2;
pub mod lib_machine;
pub mod load_state;
pub mod machine_errors;
Expand Down Expand Up @@ -42,7 +42,7 @@ use crate::machine::compile::*;
use crate::machine::copier::*;
use crate::machine::heap::*;
#[cfg(feature = "jit")]
use crate::machine::jit::*;
use crate::machine::jit2::*;
use crate::machine::loader::*;
use crate::machine::machine_errors::*;
use crate::machine::machine_indices::*;
Expand Down
4 changes: 2 additions & 2 deletions src/machine/system_calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::machine::code_walker::*;
use crate::machine::copier::*;
use crate::machine::heap::*;
#[cfg(feature = "jit")]
use crate::machine::jit::*;
use crate::machine::jit2::*;
use crate::machine::machine_errors::*;
use crate::machine::machine_indices::*;
use crate::machine::machine_state::*;
Expand Down Expand Up @@ -5038,7 +5038,7 @@ impl Machine {
let mut code = vec![];
walk_code(&self.code, first_idx, |instr| code.push(instr.clone()));

match self.jit_machine.compile(&format!("{}/{}", name.as_str(), arity), code) {
match self.jit_machine.compile(&name.as_str(), arity, code) {
Err(JitCompileError::UndefinedPredicate) => {
eprintln!("jit_compiler: undefined_predicate");
self.machine_st.fail = true;
Expand Down

0 comments on commit 3dc92f2

Please sign in to comment.