diff --git a/src/cli.rs b/src/cli.rs index 0a7b149..2b8b840 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,7 +1,7 @@ //! Provides a command line interface to tket2-hseries use std::rc::Rc; -use clap::{Command, FromArgMatches, Parser,ArgMatches, Args}; +use clap::{ArgMatches, Args, Command, FromArgMatches, Parser}; use hugr::std_extensions::arithmetic::{ conversions::EXTENSION as CONVERSIONS_EXTENSION, float_ops::EXTENSION as FLOAT_OPS_EXTENSION, float_types::EXTENSION as FLOAT_TYPES_EXTENSION, int_ops::EXTENSION as INT_OPS_EXTENSION, @@ -40,7 +40,9 @@ pub struct HugrCliCmdLineArgs(hugr_cli::CmdLineArgs); impl FromArgMatches for HugrCliCmdLineArgs { fn from_arg_matches(matches: &ArgMatches) -> Result { - Ok(HugrCliCmdLineArgs(hugr_cli::CmdLineArgs::from_arg_matches(matches)?)) + Ok(HugrCliCmdLineArgs(hugr_cli::CmdLineArgs::from_arg_matches( + matches, + )?)) } fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> { @@ -50,13 +52,11 @@ impl FromArgMatches for HugrCliCmdLineArgs { impl Args for HugrCliCmdLineArgs { fn augment_args(cmd: Command) -> Command { - hugr_cli::CmdLineArgs::augment_args(cmd) - .mut_arg("mermaid", |x| x.hide(true)) + hugr_cli::CmdLineArgs::augment_args(cmd).mut_arg("mermaid", |x| x.hide(true)) } fn augment_args_for_update(cmd: Command) -> Command { - hugr_cli::CmdLineArgs::augment_args_for_update(cmd) - .mut_arg("mermaid", |x| x.hide(true)) + hugr_cli::CmdLineArgs::augment_args_for_update(cmd).mut_arg("mermaid", |x| x.hide(true)) } } @@ -71,7 +71,7 @@ pub struct CmdLineArgs { #[arg(short='p',long,default_value=crate::emit::NAMER_DEFAULT_PREFIX)] mangle_prefix: String, - #[arg(short='s',long, default_value_t=true)] + #[arg(short = 's', long, default_value_t = true)] mangle_node_suffix: bool, } diff --git a/src/custom/int.rs b/src/custom/int.rs index 33af868..2ea2644 100644 --- a/src/custom/int.rs +++ b/src/custom/int.rs @@ -36,7 +36,18 @@ impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for IntOpEmitter<'c, '_, H> { let a = builder.build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?; args.outputs.finish(builder, [a.into()]) } - _ => Err(anyhow!("IntOpEmitter: unknown name")), + "ieq" => { + let builder = self.0.builder(); + let [lhs, rhs] = TryInto::<[_; 2]>::try_into(args.inputs).unwrap(); + let a = builder.build_int_compare( + inkwell::IntPredicate::EQ, + lhs.into_int_value(), + rhs.into_int_value(), + "", + )?; + args.outputs.finish(builder, [a.into()]) + } + n => Err(anyhow!("IntOpEmitter: unknown name: {n}")), } } } diff --git a/src/emit/func/mailbox.rs b/src/emit/func/mailbox.rs index 06bb727..a5c8b2e 100644 --- a/src/emit/func/mailbox.rs +++ b/src/emit/func/mailbox.rs @@ -93,6 +93,10 @@ impl<'c> ValuePromise<'c> { pub struct RowMailBox<'c>(Rc>>, Cow<'static, str>); impl<'c> RowMailBox<'c> { + pub fn new_empty() -> Self { + Self::new(std::iter::empty(), None) + } + pub(super) fn new( mbs: impl IntoIterator>, name: Option, diff --git a/src/emit/ops.rs b/src/emit/ops.rs index ea51b7d..27eac29 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -1,14 +1,18 @@ -use anyhow::{anyhow, Result}; +use std::collections::HashMap; + +use anyhow::{anyhow, ensure, Result}; use hugr::{ hugr::views::SiblingGraph, ops::{ - Call, Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait, - OpType, Output, Tag, UnpackTuple, Value, + Call, Case, Conditional, Const, DataflowBlock, ExitBlock, Input, LoadConstant, MakeTuple, + NamedOp, OpTag, OpTrait, OpType, Output, Tag, UnpackTuple, Value, CFG, }, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, }; -use inkwell::{builder::Builder, types::BasicType, values::BasicValueEnum}; +use inkwell::{ + basic_block::BasicBlock, builder::Builder, types::BasicType, values::BasicValueEnum, +}; use itertools::Itertools; use petgraph::visit::Walker; @@ -16,7 +20,7 @@ use crate::fat::FatExt as _; use crate::{fat::FatNode, types::LLVMSumType}; use super::{ - func::{EmitFuncContext, RowPromise}, + func::{EmitFuncContext, RowMailBox, RowPromise}, EmitOp, EmitOpArgs, }; @@ -38,8 +42,12 @@ impl<'c, 'd, H: HugrView> SumOpEmitter<'c, 'd, H> { impl<'c, H: HugrView> EmitOp<'c, MakeTuple, H> for SumOpEmitter<'c, '_, H> { fn emit(&mut self, args: EmitOpArgs<'c, MakeTuple, H>) -> Result<()> { let builder = self.0.builder(); - args.outputs - .finish(builder, [self.1.build_tag(builder, 0, args.inputs)?]) + println!("dougrulz3"); + let r = args + .outputs + .finish(builder, [self.1.build_tag(builder, 0, args.inputs)?])?; + println!("dougrulz4"); + Ok(r) } } @@ -58,13 +66,16 @@ impl<'c, H: HugrView> EmitOp<'c, UnpackTuple, H> for SumOpEmitter<'c, '_, H> { impl<'c, H: HugrView> EmitOp<'c, Tag, H> for SumOpEmitter<'c, '_, H> { fn emit(&mut self, args: EmitOpArgs<'c, Tag, H>) -> Result<()> { + println!("dougrulz5"); let builder = self.0.builder(); - args.outputs.finish( + let r = args.outputs.finish( builder, [self .1 .build_tag(builder, args.node.tag as u32, args.inputs)?], - ) + )?; + println!("dougrulz6"); + Ok(r) } } @@ -228,12 +239,196 @@ impl<'c, H: HugrView> EmitOp<'c, Conditional, H> for ConditionalEmitter<'c, '_, }) .collect::>>()?; - builder.build_switch(tag.into_int_value(), switches[0].1, &switches[1..])?; + builder.build_switch(tag, switches[0].1, &switches[1..])?; builder.position_at_end(exit_block); Ok(()) } } +struct CfgEmitter<'c, 'd, H: HugrView> { + context: &'d mut EmitFuncContext<'c, H>, + bbs: HashMap, (BasicBlock<'c>, RowMailBox<'c>)>, + inputs: Option>>, + outputs: Option>, + node: FatNode<'c, CFG, H>, + entry_node: FatNode<'c, DataflowBlock, H>, + exit_node: FatNode<'c, ExitBlock, H>, +} + +impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> { + pub fn new( + context: &'d mut EmitFuncContext<'c, H>, + args: EmitOpArgs<'c, CFG, H>, + ) -> Result { + let node = args.node(); + let (inputs, outputs) = (Some(args.inputs), Some(args.outputs)); + let out_types = node.out_value_types().map(|x| x.1).collect_vec(); + let output_row = context.new_row_mail_box(out_types.iter(), "")?; + let exit_block = context.new_basic_block("", None); + let bbs = node + .children() + .map(|child| { + if child.is_exit_block() { + Ok((child, (exit_block, output_row.clone()))) + } else { + let bb = context.new_basic_block("", Some(exit_block)); + let (i, _) = child.get_io().unwrap(); + Ok((child, (bb, context.node_outs_rmb(i)?))) + } + }) + .collect::>>()?; + let [entry_node, exit_node] = node + .children() + .take(2) + .collect_vec() + .try_into() + .map_err(|_| anyhow!("cfg doesn't have two children"))?; + let entry_node = entry_node.try_into_ot::().unwrap(); + let exit_node = exit_node.try_into_ot::().unwrap(); + Ok(CfgEmitter { + context, + bbs, + node, + inputs, + outputs, + entry_node, + exit_node, + }) + } + + fn take_inputs(&mut self) -> Result>> { + self.inputs.take().ok_or(anyhow!("Couldn't take inputs")) + } + + fn take_outputs(&mut self) -> Result> { + self.outputs.take().ok_or(anyhow!("Couldn't take inputs")) + } + + fn get_block_data( + &self, + node: &FatNode<'c, OT, H>, + ) -> Result<&(BasicBlock<'c>, RowMailBox<'c>)> + where + OT: Into + 'c, + { + self.bbs + .get(&node.clone().generalise()) + .ok_or(anyhow!("Couldn't get block data for: {}", node.index())) + } + + fn emit_children(mut self) -> Result<()> { + let inputs = self.take_inputs()?; + let (entry_bb, inputs_rmb) = self.get_block_data(&self.entry_node).cloned()?; + let builder = self.context.builder(); + inputs_rmb.write(builder, inputs)?; + builder.build_unconditional_branch(entry_bb)?; + + for c in self.node.children() { + let (inputs, outputs) = (vec![], RowMailBox::new_empty().promise()); + if let Some(node) = c.try_into_ot::() { + self.emit(EmitOpArgs { + node, + inputs, + outputs, + })?; + } else if let Some(node) = c.try_into_ot::() { + self.emit(EmitOpArgs { + node, + inputs, + outputs, + })?; + } else { + Err(anyhow!("unknown optype: {c}"))?; + } + } + let outputs = self.take_outputs()?; + let (exit_bb, outputs_rmb) = self.get_block_data(&self.exit_node).cloned()?; + let builder = self.context.builder(); + builder.position_at_end(exit_bb); + outputs.finish(builder, outputs_rmb.read_vec(builder, [])?)?; + Ok(()) + } +} + +impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> { + fn emit( + &mut self, + EmitOpArgs { + node, + inputs: _, + outputs: _, + }: EmitOpArgs<'c, DataflowBlock, H>, + ) -> Result<()> { + let (bb, inputs_rmb) = self.bbs.get(&node.clone().generalise()).unwrap(); + let (_, o) = node.get_io().unwrap(); + let successor_data = node + .output_neighbours() + .map(|succ| self.get_block_data(&succ).map(|x| x.clone())) + .collect::>>()?; + + self.context.build_positioned(*bb, |context| { + let outputs_rmb = context.node_ins_rmb(o)?; + let inputs = inputs_rmb.read_vec(context.builder(), [])?; + emit_dataflow_parent( + context, + EmitOpArgs { + node: node.clone(), + inputs, + outputs: outputs_rmb.promise(), + }, + )?; + + let outputs = outputs_rmb.read_vec(context.builder(), [])?; + let branch_sum_type = SumType::new(node.sum_rows.clone()); + let llvm_sum_type = context.llvm_sum_type(branch_sum_type)?; + let tag_bbs = successor_data + .into_iter() + .enumerate() + .map(|(tag, (target_bb, target_rmb))| { + let bb = context.build_positioned_new_block("", Some(*bb), |context, bb| { + let builder = context.builder(); + let mut vals = + llvm_sum_type.build_untag(builder, tag as u32, outputs[0])?; + vals.extend(&outputs[1..]); + target_rmb.write(builder, vals)?; + builder.build_unconditional_branch(target_bb)?; + Ok::<_, anyhow::Error>(bb) + })?; + Ok(( + llvm_sum_type.get_tag_type().const_int(tag as u64, false), + bb, + )) + }) + .collect::>>()?; + let tag_v = llvm_sum_type.build_get_tag(context.builder(), outputs[0])?; + context + .builder() + .build_switch(tag_v, tag_bbs[0].1, &tag_bbs[1..])?; + Ok(()) + }) + } +} + +impl<'c, H: HugrView> EmitOp<'c, ExitBlock, H> for CfgEmitter<'c, '_, H> { + fn emit( + &mut self, + EmitOpArgs { + node, + inputs: _, + outputs: _, + }: EmitOpArgs<'c, ExitBlock, H>, + ) -> Result<()> { + Ok(()) + } +} + +fn emit_cfg<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, H>, + args: EmitOpArgs<'c, CFG, H>, +) -> Result<()> { + CfgEmitter::new(context, args)?.emit_children() +} + fn get_exactly_one_sum_type(ts: impl IntoIterator) -> Result { let Some(TypeEnum::Sum(sum_type)) = ts .into_iter() @@ -250,32 +445,41 @@ fn emit_value<'c, H: HugrView>( context: &mut EmitFuncContext<'c, H>, v: &Value, ) -> Result> { + println!("emit_value: {v:?}"); match v { Value::Extension { e } => { let exts = context.extensions(); - exts.load_constant(context, e.value()) + let val = exts.load_constant(context, e.value())?; + ensure!(val.get_type() == context.llvm_type(&v.get_type())?); + Ok(val) } Value::Function { .. } => todo!(), Value::Tuple { vs } => { + println!("dougrulz1"); let tys = vs.iter().map(|x| x.get_type()).collect_vec(); let llvm_st = LLVMSumType::try_new(&context.typing_session(), SumType::new([tys]))?; let llvm_vs = vs .iter() .map(|x| emit_value(context, x)) .collect::>>()?; - llvm_st.build_tag(context.builder(), 0, llvm_vs) + let r = llvm_st.build_tag(context.builder(), 0, llvm_vs)?; + println!("dougrulz2"); + Ok(r) } Value::Sum { tag, values, sum_type, } => { + println!("dougrulz7"); let llvm_st = LLVMSumType::try_new(&context.typing_session(), sum_type.clone())?; let vs = values .iter() .map(|x| emit_value(context, x)) .collect::>>()?; - llvm_st.build_tag(context.builder(), *tag as u32, vs) + let r = llvm_st.build_tag(context.builder(), *tag as u32, vs)?; + println!("dougrulz8"); + Ok(r) } } } @@ -381,7 +585,6 @@ fn emit_optype<'c, H: HugrView>( OpType::Tag(ref tag) => emit_tag(context, args.into_ot(tag)), OpType::DFG(_) => emit_dataflow_parent(context, args), - // TODO Test cases OpType::CustomOp(ref co) => { let extensions = context.extensions(); extensions.emit(context, args.into_ot(co)) @@ -390,6 +593,7 @@ fn emit_optype<'c, H: HugrView>( OpType::LoadConstant(ref lc) => emit_load_constant(context, args.into_ot(lc)), OpType::Call(ref cl) => emit_call(context, args.into_ot(cl)), OpType::Conditional(ref co) => emit_conditional(context, args.into_ot(co)), + OpType::CFG(ref cfg) => emit_cfg(context, args.into_ot(cfg)), // OpType::FuncDefn(fd) => self.emit(ot.into_ot(fd), context, inputs, outputs), _ => todo!("Unimplemented OpTypeEmitter: {}", args.node().name()), diff --git a/src/fat.rs b/src/fat.rs index b9971d1..f69c1bb 100644 --- a/src/fat.rs +++ b/src/fat.rs @@ -155,11 +155,20 @@ impl<'c, OT, H: HugrView + ?Sized> FatNode<'c, OT, H> { )) } + pub fn node_outputs(&self) -> impl Iterator + '_ { + self.hugr.node_outputs(self.node) + } + + pub fn output_neighbours(&self) -> impl Iterator> + '_ { + self.hugr + .output_neighbours(self.node) + .map(|n| FatNode::new_optype(self.hugr, n)) + } + /// Create a general `FatNode` from a specific one. pub fn generalise(self) -> FatNode<'c, OpType, H> where - &'c OpType: TryInto<&'c OT>, - OT: 'c, + OT: Into + 'c, { // guaranteed to be valid becasue self is valid FatNode { @@ -271,6 +280,12 @@ impl<'c, OT, H> NodeIndex for FatNode<'c, OT, H> { } } +impl<'c, OT, H> NodeIndex for &FatNode<'c, OT, H> { + fn index(self) -> usize { + self.node.index() + } +} + /// An extension trait for [HugrView] which provides methods that delegate to /// [HugrView] and then return the result in [FatNode] form. See for example /// [FatExt::fat_io]. diff --git a/src/types.rs b/src/types.rs index 229e98c..d922d6d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,7 +8,7 @@ use hugr::types::SumType; use hugr::{types::TypeRow, HugrView}; use inkwell::builder::Builder; use inkwell::types::{self as iw, AnyType, AsTypeRef, IntType}; -use inkwell::values::{BasicValue, BasicValueEnum, StructValue}; +use inkwell::values::{BasicValue, BasicValueEnum, IntValue, StructValue}; use inkwell::{ context::Context, types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum, StructType}, @@ -196,13 +196,15 @@ impl<'c> LLVMSumType<'c> { &self, builder: &Builder<'c>, v: impl BasicValue<'c>, - ) -> Result> { + ) -> Result> { let struct_value: StructValue<'c> = v .as_basic_value_enum() .try_into() .map_err(|_| anyhow!("Not a struct type"))?; if self.has_tag_field() { - Ok(builder.build_extract_value(struct_value, 0, "")?) + Ok(builder + .build_extract_value(struct_value, 0, "")? + .into_int_value()) } else { Ok(self.get_tag_type().const_int(0, false).into()) } @@ -244,7 +246,7 @@ impl<'c> LLVMSumType<'c> { ) -> Result> { let expected_num_fields = self.num_fields(tag)?; if expected_num_fields != vs.len() { - Err(anyhow!("LLVMSumType::build: wrong number of fields: expected: {expected_num_fields} actual: {}", vs.len()))? + Err(anyhow!("LLVMSumType::build_tag: wrong number of fields: expected: {expected_num_fields} actual: {} sumtype: {} llvm_type {} vs: {:?}", vs.len(), &self.1, &self.0, vs.clone()))? } let variant_index = self.get_variant_index(tag); let row_t = self diff --git a/tests/guppy.rs b/tests/guppy.rs index eee3277..21ddb8e 100644 --- a/tests/guppy.rs +++ b/tests/guppy.rs @@ -1,9 +1,13 @@ -use std::{env, fs::{read_to_string, File}, path::{Path, PathBuf}, process::Command}; +use std::{ + env, + fs::File, + path::{Path, PathBuf}, + process::Command, +}; -use insta::assert_snapshot; -use rstest::{fixture, rstest}; use insta_cmd::assert_cmd_snapshot; -use tempfile::{tempfile, NamedTempFile}; +use rstest::{fixture, rstest}; +use tempfile::NamedTempFile; struct TestConfig { python_bin: PathBuf, @@ -19,11 +23,13 @@ impl TestConfig { .or_else(|| pathsearch::find_executable_in_path("python")) .unwrap_or_else(|| panic!("Could not find python in PATH or HUGR_LLVM_PYTHON_BIN")); let hugr_llvm_bin = env!("CARGO_BIN_EXE_hugr-llvm").into(); - let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/guppy_test_cases").into(); + let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests/guppy_test_cases") + .into(); TestConfig { python_bin, hugr_llvm_bin, - test_dir + test_dir, } } } @@ -35,8 +41,13 @@ impl TestConfig { .arg(self.test_dir.join(path.as_ref())) .arg("--mermaid") .stdout(file.reopen().unwrap()) - .status().unwrap(); - assert!(status.success(), "Failed to run guppy test case: {:?}", path.as_ref()); + .status() + .unwrap(); + assert!( + status.success(), + "Failed to run guppy test case: {:?}", + path.as_ref() + ); file }