From 4a6e350b4389f7903c186560596da4d68ff54cee Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Fri, 10 Jan 2025 13:44:10 +0000 Subject: [PATCH] lint --- src/cli.rs | 4 +- src/lib.rs | 9 +- src/py.rs | 8 +- src/qir.rs | 318 +++++++++++++++++++++++++++++++----------------- src/rotation.rs | 2 +- 5 files changed, 219 insertions(+), 122 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 132c6db..6850316 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -8,7 +8,7 @@ use hugr::llvm::inkwell; use hugr_cli::HugrArgs; use itertools::Itertools; -use crate::{CompileArgs}; +use crate::CompileArgs; /// Main command line interface #[derive(Parser, Debug)] @@ -104,7 +104,7 @@ impl Cli { debug: self.debug, save_hugr: self.save_hugr.clone(), verbosity: self.hugr_args.verbose.log_level(), - validate: self.validate + validate: self.validate, } } } diff --git a/src/lib.rs b/src/lib.rs index e20dfac..b5b0599 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,10 @@ -use std::fs::{self, OpenOptions}; +use std::fs::{OpenOptions}; use std::rc::Rc; use anyhow::Result; use clap_verbosity_flag::log::Level; -use hugr::algorithms::validation::ValidationLevel; use hugr::llvm::custom::CodegenExtsMap; use hugr::llvm::emit::{EmitHugr, Namer}; -use hugr::llvm::extension::DefaultPreludeCodegen; use hugr::llvm::utils::fat::FatExt; use hugr::llvm::{inkwell, CodegenExtsBuilder}; use hugr::Hugr; @@ -14,13 +12,12 @@ use inkwell::context::Context; use inkwell::module::Module; use qir::{QirCodegenExtension, QirPreludeCodegen}; use rotation::RotationCodegenExtension; -use tket2_hseries::QSystemPass; pub mod cli; pub mod qir; // TODO this was copy pasted, ideally it would live in tket2-hseries -pub mod rotation; mod py; +pub mod rotation; #[non_exhaustive] pub struct CompileArgs { @@ -31,7 +28,6 @@ pub struct CompileArgs { pub validate: bool, } - impl CompileArgs { pub fn codegen_extensions(&self) -> CodegenExtsMap<'static, Hugr> { // TODO: we probably need to customise prelude codegen @@ -55,7 +51,6 @@ impl CompileArgs { "hugr-qir" } - /// TODO: hugr: &mut impl HugrMut pub fn hugr_to_hugr(&self, hugr: &mut Hugr) -> Result<()> { // note this rebases into tket2.qsystem extension diff --git a/src/py.rs b/src/py.rs index ae6ef93..d25735a 100644 --- a/src/py.rs +++ b/src/py.rs @@ -5,8 +5,8 @@ use hugr::llvm::inkwell; use itertools::Itertools as _; use pyo3::{ pyfunction, pymodule, - types::{PyAnyMethods as _, PyDict, PyModule, PyModuleMethods as _, PyTuple}, - wrap_pyfunction, Bound, PyAny, PyResult, Python, + types::{PyAnyMethods as _, PyModule, PyModuleMethods as _, PyTuple}, + wrap_pyfunction, Bound, PyResult, }; use crate::cli::Cli; @@ -14,7 +14,9 @@ use crate::cli::Cli; #[pyfunction] #[pyo3(signature = (*args))] pub fn cli(args: &Bound) -> PyResult<()> { - let args = iter::once("hugr-qir".into()).chain(args.extract::>()?).collect_vec(); + let args = iter::once("hugr-qir".into()) + .chain(args.extract::>()?) + .collect_vec(); let context = inkwell::context::Context::create(); Cli::try_parse_from(args) .map_err(anyhow::Error::from)? diff --git a/src/qir.rs b/src/qir.rs index abb4b9f..e189df1 100644 --- a/src/qir.rs +++ b/src/qir.rs @@ -1,33 +1,60 @@ -use hugr::{extension::{prelude::{qb_t, ConstString}, simple_op::MakeExtensionOp as _}, llvm::{extension::PreludeCodegen, CodegenExtension, CodegenExtsBuilder}, ops::{ExtensionOp, Value}, HugrView}; use anyhow::{anyhow, bail, Result}; use hugr::llvm as hugr_llvm; +use hugr::{ + extension::{ + prelude::{qb_t, ConstString}, + simple_op::MakeExtensionOp as _, + }, + llvm::{extension::PreludeCodegen, CodegenExtension, CodegenExtsBuilder}, + ops::{ExtensionOp, Value}, + HugrView, +}; use hugr_llvm::inkwell; -use inkwell::{context::Context, types::{BasicMetadataTypeEnum, BasicType, FloatType, PointerType}, values::BasicMetadataValueEnum}; +use inkwell::{ + context::Context, + types::{BasicMetadataTypeEnum, BasicType, FloatType}, + values::BasicMetadataValueEnum, +}; use itertools::Itertools; use tket2::extension::rotation::rotation_type; use tket2_hseries::extension::result::{ResultOp, ResultOpDef}; -use hugr_llvm::{emit::{emit_value, EmitFuncContext, EmitOpArgs}, sum::LLVMSumValue, types::{HugrSumType, TypingSession}}; +use hugr_llvm::{ + emit::{emit_value, EmitFuncContext, EmitOpArgs}, + sum::LLVMSumValue, + types::{HugrSumType, TypingSession}, +}; - -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct QirPreludeCodegen; impl PreludeCodegen for QirPreludeCodegen { - fn qubit_type<'c,'d>(&self, session: &TypingSession<'c,'d>) -> impl BasicType<'c> { + fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> { let iw_ctx = session.iw_context(); - iw_ctx.get_struct_type("QUBIT").unwrap_or_else(|| iw_ctx.opaque_struct_type("QUBIT")).ptr_type(Default::default()) + iw_ctx + .get_struct_type("QUBIT") + .unwrap_or_else(|| iw_ctx.opaque_struct_type("QUBIT")) + .ptr_type(Default::default()) } } -fn result_type<'c>(ctx: &'c Context) -> impl BasicType<'c> { - ctx.get_struct_type("RESULT").unwrap_or_else(|| ctx.opaque_struct_type("RESULT")).ptr_type(Default::default()) +fn result_type(ctx: &Context) -> impl BasicType<'_> { + ctx.get_struct_type("RESULT") + .unwrap_or_else(|| ctx.opaque_struct_type("RESULT")) + .ptr_type(Default::default()) } -fn emit_qir_1f_xqb<'c,'d,H: HugrView>(context: &mut EmitFuncContext<'c,'d,H>, args: EmitOpArgs<'c, '_, ExtensionOp, H>, func: impl AsRef) -> Result<()> { +fn emit_qir_1f_xqb<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, ExtensionOp, H>, + func: impl AsRef, +) -> Result<()> { let iw_ctx = context.iw_context(); let qb_t = context.llvm_type(&qb_t())?; - let half_turns_t: FloatType = context.llvm_type(&rotation_type())?.try_into().map_err(|_| anyhow!("hugr type 'rotation' does not map to an LLVM float type"))?; + let half_turns_t: FloatType = context + .llvm_type(&rotation_type())? + .try_into() + .map_err(|_| anyhow!("hugr type 'rotation' does not map to an LLVM float type"))?; let args_tys = { let mut x = vec![BasicMetadataTypeEnum::from(half_turns_t)]; x.extend((0..args.inputs.len() - 1).map(|_| BasicMetadataTypeEnum::from(qb_t))); @@ -36,21 +63,36 @@ fn emit_qir_1f_xqb<'c,'d,H: HugrView>(context: &mut EmitFuncContext<'c,'d,H>, ar let func_ty = iw_ctx.void_type().fn_type(&args_tys, false); let func = context.get_extern_func(func, func_ty)?; - let qb_inputs = args.inputs.iter().copied().take(args.inputs.len() - 1).collect_vec(); + let qb_inputs = args + .inputs + .iter() + .copied() + .take(args.inputs.len() - 1) + .collect_vec(); let func_inputs = { let mut x = vec![args.inputs.last().copied().unwrap().into()]; - x.extend(qb_inputs.iter().copied().map_into::()); + x.extend( + qb_inputs + .iter() + .copied() + .map_into::(), + ); x }; context.builder().build_call(func, &func_inputs, "")?; args.outputs.finish(context.builder(), qb_inputs) - } -fn emit_qir_xqb<'c,'d,H: HugrView>(context: &mut EmitFuncContext<'c,'d,H>, args: EmitOpArgs<'c, '_, ExtensionOp, H>, func: impl AsRef) -> Result<()> { +fn emit_qir_xqb<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, ExtensionOp, H>, + func: impl AsRef, +) -> Result<()> { let iw_ctx = context.iw_context(); let qb_t = context.llvm_type(&qb_t())?; - let func_ty = iw_ctx.void_type().fn_type(&vec![qb_t.into();args.inputs.len()], false); + let func_ty = iw_ctx + .void_type() + .fn_type(&vec![qb_t.into(); args.inputs.len()], false); let func = context.get_extern_func(func, func_ty)?; let func_inputs = args.inputs.iter().copied().map_into().collect_vec(); @@ -66,9 +108,10 @@ impl CodegenExtension for QirCodegenExtension { builder: CodegenExtsBuilder<'a, H>, ) -> CodegenExtsBuilder<'a, H> where - Self: 'a { - builder.simple_extension_op::(|context, args, op| { - match op { + Self: 'a, + { + builder + .simple_extension_op::(|context, args, op| match op { tket2::Tk2Op::H => emit_qir_xqb(context, args, "__quantum__qis__h__body"), tket2::Tk2Op::CX => emit_qir_xqb(context, args, "__quantum__qis__cx__body"), tket2::Tk2Op::CY => emit_qir_xqb(context, args, "__quantum__qis__cy__body"), @@ -89,14 +132,29 @@ impl CodegenExtension for QirCodegenExtension { let iw_ctx = context.iw_context(); let qb_t = qb.get_type(); let res_t = result_type(iw_ctx); - let measure_t = res_t.fn_type(&vec![qb_t.into()], false); - let measure_func = context.get_extern_func("__quantum__qis__m__body", measure_t)?; + let measure_t = res_t.fn_type(&[qb_t.into()], false); + let measure_func = + context.get_extern_func("__quantum__qis__m__body", measure_t)?; - let read_result_t = iw_ctx.bool_type().fn_type(&[res_t.as_basic_type_enum().into()], false); - let read_result_func = context.get_extern_func("__quantum__qis__read_result__body", read_result_t)?; + let read_result_t = iw_ctx + .bool_type() + .fn_type(&[res_t.as_basic_type_enum().into()], false); + let read_result_func = context + .get_extern_func("__quantum__qis__read_result__body", read_result_t)?; - let result = context.builder().build_call(measure_func, &[qb.into()], "")?.try_as_basic_value().left().ok_or_else(|| anyhow!("expected a result from measure"))?; - let result_i1 = context.builder().build_call(read_result_func, &[result.into()], "")?.try_as_basic_value().left().ok_or_else(|| anyhow!("expected a bool from read_result"))?.into_int_value(); + let result = context + .builder() + .build_call(measure_func, &[qb.into()], "")? + .try_as_basic_value() + .left() + .ok_or_else(|| anyhow!("expected a result from measure"))?; + let result_i1 = context + .builder() + .build_call(read_result_func, &[result.into()], "")? + .try_as_basic_value() + .left() + .ok_or_else(|| anyhow!("expected a bool from read_result"))? + .into_int_value(); let true_val = emit_value(context, &Value::true_val())?; let false_val = emit_value(context, &Value::false_val())?; @@ -104,26 +162,42 @@ impl CodegenExtension for QirCodegenExtension { .builder() .build_select(result_i1, true_val, false_val, "")?; args.outputs.finish(context.builder(), [qb, res]) - }, + } tket2::Tk2Op::MeasureFree => { let qb = args.inputs[0]; let iw_ctx = context.iw_context(); let qb_t = qb.get_type(); let res_t = result_type(iw_ctx); - let measure_t = res_t.fn_type(&vec![qb_t.into()], false); - let measure_func = context.get_extern_func("__quantum__qis__m__body", measure_t)?; + let measure_t = res_t.fn_type(&[qb_t.into()], false); + let measure_func = + context.get_extern_func("__quantum__qis__m__body", measure_t)?; - let read_result_t = iw_ctx.bool_type().fn_type(&[res_t.as_basic_type_enum().into()], false); - let read_result_func = context.get_extern_func("__quantum__qis__read_result__body", read_result_t)?; + let read_result_t = iw_ctx + .bool_type() + .fn_type(&[res_t.as_basic_type_enum().into()], false); + let read_result_func = context + .get_extern_func("__quantum__qis__read_result__body", read_result_t)?; - let result = context.builder().build_call(measure_func, &[qb.into()], "")?.try_as_basic_value().left().ok_or_else(|| anyhow!("expected a result from measure"))?; - let result_i1 = context.builder().build_call(read_result_func, &[result.into()], "")?.try_as_basic_value().left().ok_or_else(|| anyhow!("expected a bool from read_result"))?.into_int_value(); + let result = context + .builder() + .build_call(measure_func, &[qb.into()], "")? + .try_as_basic_value() + .left() + .ok_or_else(|| anyhow!("expected a result from measure"))?; + let result_i1 = context + .builder() + .build_call(read_result_func, &[result.into()], "")? + .try_as_basic_value() + .left() + .ok_or_else(|| anyhow!("expected a bool from read_result"))? + .into_int_value(); let true_val = emit_value(context, &Value::true_val())?; let false_val = emit_value(context, &Value::false_val())?; let qfree_t = iw_ctx.void_type().fn_type(&[qb_t.into()], false); - let qfree_func = context.get_extern_func("__quantum__rt__qubit_release", qfree_t)?; + let qfree_func = + context.get_extern_func("__quantum__rt__qubit_release", qfree_t)?; context.builder().build_call(qfree_func, &[qb.into()], "")?; let res = context @@ -133,92 +207,118 @@ impl CodegenExtension for QirCodegenExtension { } tket2::Tk2Op::QAlloc => { let qb_t = context.llvm_type(&qb_t())?; - let qalloc_t = qb_t.fn_type(&vec![], false); - let qalloc_func = context.get_extern_func("__quantum__rt__qubit_allocate", qalloc_t)?; - let qb = context.builder().build_call(qalloc_func, &[], "")?.try_as_basic_value().left().ok_or_else(|| anyhow!("expected a qubit from qalloc"))?; + let qalloc_t = qb_t.fn_type(&[], false); + let qalloc_func = + context.get_extern_func("__quantum__rt__qubit_allocate", qalloc_t)?; + let qb = context + .builder() + .build_call(qalloc_func, &[], "")? + .try_as_basic_value() + .left() + .ok_or_else(|| anyhow!("expected a qubit from qalloc"))?; args.outputs.finish(context.builder(), [qb]) - }, + } tket2::Tk2Op::QFree => { let iw_ctx = context.iw_context(); let qb = args.inputs[0]; let qb_t = qb.get_type(); let qfree_t = iw_ctx.void_type().fn_type(&[qb_t.into()], false); - let qfree_func = context.get_extern_func("__quantum__rt__qubit_release", qfree_t)?; + let qfree_func = + context.get_extern_func("__quantum__rt__qubit_release", qfree_t)?; context.builder().build_call(qfree_func, &[qb.into()], "")?; args.outputs.finish(context.builder(), []) } - _ => bail!("Unknown op: {op:?}") - } - - }).simple_extension_op::(|context, args, op| { - let result_op = ResultOp::from_extension_op(&args.node())?; - let tag_str = result_op.tag; - if tag_str.is_empty() { - return Err(anyhow!("Empty result tag received")); - } - - let tag_ptr = emit_value(context, &ConstString::new(tag_str).into())?; - let i8_ptr_ty = context.iw_context().i8_type().ptr_type(Default::default()).as_basic_type_enum(); - - match op { - ResultOpDef::Bool => { - let [val] = args - .inputs - .try_into() - .map_err(|_| anyhow!("result_bool expects one input"))?; - let bool_type = context.llvm_sum_type(HugrSumType::new_unary(2))?; - let val = LLVMSumValue::try_new(val, bool_type) - .map_err(|_| anyhow!("bool_type expects a value"))? - .build_get_tag(context.builder())?; - let i1_ty = context.iw_context().bool_type(); - let trunc_val = context.builder().build_int_truncate(val, i1_ty, "")?; - let print_fn_ty = context.iw_context().void_type().fn_type(&[i1_ty.into(), i8_ptr_ty.into(), ], false); - let print_fn = context.get_extern_func("__quantum__rt__bool_record_output", print_fn_ty)?; - context.builder().build_call( - print_fn, - &[trunc_val.into(), tag_ptr.into()], - "print_bool", - )?; - args.outputs.finish(context.builder(), []) - } - ResultOpDef::Int | ResultOpDef::UInt => { - let [val] = args - .inputs - .try_into() - .map_err(|_| anyhow!("result_bool expects one input"))?; - let i64_ty = context.iw_context().i64_type(); - let print_fn_ty = context.iw_context().void_type().fn_type(&[i64_ty.into(), i8_ptr_ty.into(), ], false); - let print_fn = context.get_extern_func("__quantum__rt__int_record_output", print_fn_ty)?; - context.builder().build_call( - print_fn, - &[val.into(), tag_ptr.into()], - "print_bool", - )?; - args.outputs.finish(context.builder(), []) - } - ResultOpDef::F64 => { - let [val] = args - .inputs - .try_into() - .map_err(|_| anyhow!("result_bool expects one input"))?; - let f64_ty = context.iw_context().f64_type(); - let print_fn_ty = context.iw_context().void_type().fn_type(&[f64_ty.into(), i8_ptr_ty.into(), ], false); - let print_fn = context.get_extern_func("__quantum__rt__double_record_output", print_fn_ty)?; - context.builder().build_call( - print_fn, - &[val.into(), tag_ptr.into()], - "print_bool", - )?; - args.outputs.finish(context.builder(), []) - } - ResultOpDef::ArrBool => todo!(), - ResultOpDef::ArrInt => todo!(), - ResultOpDef::ArrUInt => todo!(), - ResultOpDef::ArrF64 => todo!(), - _ => todo!(), - } + _ => bail!("Unknown op: {op:?}"), + }) + .simple_extension_op::( + |context, args, op| { + let result_op = ResultOp::from_extension_op(&args.node())?; + let tag_str = result_op.tag; + if tag_str.is_empty() { + return Err(anyhow!("Empty result tag received")); + } - }) + let tag_ptr = emit_value(context, &ConstString::new(tag_str).into())?; + let i8_ptr_ty = context + .iw_context() + .i8_type() + .ptr_type(Default::default()) + .as_basic_type_enum(); + match op { + ResultOpDef::Bool => { + let [val] = args + .inputs + .try_into() + .map_err(|_| anyhow!("result_bool expects one input"))?; + let bool_type = context.llvm_sum_type(HugrSumType::new_unary(2))?; + let val = LLVMSumValue::try_new(val, bool_type) + .map_err(|_| anyhow!("bool_type expects a value"))? + .build_get_tag(context.builder())?; + let i1_ty = context.iw_context().bool_type(); + let trunc_val = context.builder().build_int_truncate(val, i1_ty, "")?; + let print_fn_ty = context + .iw_context() + .void_type() + .fn_type(&[i1_ty.into(), i8_ptr_ty.into()], false); + let print_fn = context.get_extern_func( + "__quantum__rt__bool_record_output", + print_fn_ty, + )?; + context.builder().build_call( + print_fn, + &[trunc_val.into(), tag_ptr.into()], + "print_bool", + )?; + args.outputs.finish(context.builder(), []) + } + ResultOpDef::Int | ResultOpDef::UInt => { + let [val] = args + .inputs + .try_into() + .map_err(|_| anyhow!("result_bool expects one input"))?; + let i64_ty = context.iw_context().i64_type(); + let print_fn_ty = context + .iw_context() + .void_type() + .fn_type(&[i64_ty.into(), i8_ptr_ty.into()], false); + let print_fn = context + .get_extern_func("__quantum__rt__int_record_output", print_fn_ty)?; + context.builder().build_call( + print_fn, + &[val.into(), tag_ptr.into()], + "print_bool", + )?; + args.outputs.finish(context.builder(), []) + } + ResultOpDef::F64 => { + let [val] = args + .inputs + .try_into() + .map_err(|_| anyhow!("result_bool expects one input"))?; + let f64_ty = context.iw_context().f64_type(); + let print_fn_ty = context + .iw_context() + .void_type() + .fn_type(&[f64_ty.into(), i8_ptr_ty.into()], false); + let print_fn = context.get_extern_func( + "__quantum__rt__double_record_output", + print_fn_ty, + )?; + context.builder().build_call( + print_fn, + &[val.into(), tag_ptr.into()], + "print_bool", + )?; + args.outputs.finish(context.builder(), []) + } + ResultOpDef::ArrBool => todo!(), + ResultOpDef::ArrInt => todo!(), + ResultOpDef::ArrUInt => todo!(), + ResultOpDef::ArrF64 => todo!(), + _ => todo!(), + } + }, + ) } } diff --git a/src/rotation.rs b/src/rotation.rs index fa5c0b8..107801b 100644 --- a/src/rotation.rs +++ b/src/rotation.rs @@ -1,8 +1,8 @@ use anyhow::{anyhow, bail, Result}; -use hugr::llvm::inkwell; use hugr::llvm::custom::CodegenExtsBuilder; use hugr::llvm::emit::{emit_value, EmitFuncContext, EmitOpArgs}; use hugr::llvm::extension::{DefaultPreludeCodegen, PreludeCodegen}; +use hugr::llvm::inkwell; use hugr::llvm::types::TypingSession; use hugr::llvm::CodegenExtension; use inkwell::types::FloatType;