From d239a644a1e0867a6d9ffb82012da62dc38456a8 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Sun, 20 Oct 2024 11:42:35 +0100 Subject: [PATCH] test: Add tests for exact float <-> int roundtrips --- src/extension/conversions.rs | 112 ++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/src/extension/conversions.rs b/src/extension/conversions.rs index 1dde820..326d419 100644 --- a/src/extension/conversions.rs +++ b/src/extension/conversions.rs @@ -1,17 +1,21 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, ensure, Result}; use hugr::{ extension::{ prelude::{sum_with_error, ConstError, BOOL_T}, simple_op::MakeExtensionOp, }, - ops::{constant::Value, custom::ExtensionOp}, + ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _}, std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - types::{TypeArg, TypeEnum}, + types::{TypeArg, TypeEnum, TypeRow}, HugrView, }; -use inkwell::{values::BasicValue, FloatPredicate, IntPredicate}; +use inkwell::{ + types::IntType, + values::BasicValue, + FloatPredicate, IntPredicate, +}; use crate::{ custom::{CodegenExtension, CodegenExtsBuilder}, @@ -21,6 +25,7 @@ use crate::{ EmitOpArgs, }, sum::LLVMSumValue, + types::HugrType, }; fn build_trunc_op<'c, H: HugrView>( @@ -29,38 +34,45 @@ fn build_trunc_op<'c, H: HugrView>( log_width: u64, args: EmitOpArgs<'c, '_, ExtensionOp, H>, ) -> Result<()> { - // Note: This logic is copied from `llvm_type` in the IntTypes - // extension. We need to have a common source of truth for this. - let (width, (int_min_value_s, int_max_value_s), int_max_value_u) = match log_width { - 0..=3 => (8, (i8::MIN as i64, i8::MAX as i64), u8::MAX as u64), - 4 => (16, (i16::MIN as i64, i16::MAX as i64), u16::MAX as u64), - 5 => (32, (i32::MIN as i64, i32::MAX as i64), u32::MAX as u64), - 6 => (64, (i64::MIN, i64::MAX), u64::MAX), - m => return Err(anyhow!("ConversionEmitter: unsupported log_width: {}", m)), - }; - let hugr_int_ty = INT_TYPES[log_width as usize].clone(); - let int_ty = context - .typing_session() - .llvm_type(&hugr_int_ty)? - .into_int_type(); + let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]); + // TODO: it would be nice to get this info out of `ops.node()`, this would + // require adding appropriate methods to `ConvertOpDef`. In the meantime, we + // assert that the output types are as we expect. + debug_assert_eq!( + TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]), + args.node().signature().output + ); + + let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else { + bail!("Expected `arithmetic.int` to lower to an llvm integer") + }; - let hugr_sum_ty = sum_with_error(vec![hugr_int_ty]); - let sum_ty = context.typing_session().llvm_sum_type(hugr_sum_ty)?; + let sum_ty = context.llvm_sum_type(hugr_sum_ty)?; + + let (width, int_min_value_s, int_max_value_s, int_max_value_u) = { + ensure!( + log_width <= 6, + "Expected log_width of output to be <= 6, found: {log_width}" + ); + let width = 1 << log_width; + ( + width, + i64::MIN >> (64 - width), + i64::MAX >> (64 - width), + u64::MAX >> (64 - width), + ) + }; emit_custom_unary_op(context, args, |ctx, arg, _| { // We have to check if the conversion will work, so we // make the maximum int and convert to a float, then compare // with the function input. - let flt_max = if signed { - ctx.iw_context() - .f64_type() - .const_float(int_max_value_s as f64) + let flt_max = ctx.iw_context().f64_type().const_float(if signed { + int_max_value_s as f64 } else { - ctx.iw_context() - .f64_type() - .const_float(int_max_value_u as f64) - }; + int_max_value_u as f64 + }); let within_upper_bound = ctx.builder().build_float_compare( FloatPredicate::OLT, @@ -69,13 +81,11 @@ fn build_trunc_op<'c, H: HugrView>( "within_upper_bound", )?; - let flt_min = if signed { - ctx.iw_context() - .f64_type() - .const_float(int_min_value_s as f64) + let flt_min = ctx.iw_context().f64_type().const_float(if signed { + int_min_value_s as f64 } else { - ctx.iw_context().f64_type().const_float(0.0) - }; + 0.0 + }); let within_lower_bound = ctx.builder().build_float_compare( FloatPredicate::OLE, @@ -414,26 +424,20 @@ mod test { .outputs_arr(); let [flt] = { let op = if signed { - ConvertOpDef::convert_s.with_log_width(6) + ConvertOpDef::convert_s.with_log_width(6) } else { - ConvertOpDef::convert_u.with_log_width(6) + ConvertOpDef::convert_u.with_log_width(6) }; - builder - .add_dataflow_op(op, [int]) - .unwrap() - .outputs_arr() + builder.add_dataflow_op(op, [int]).unwrap().outputs_arr() }; let [int_or_err] = { let op = if signed { - ConvertOpDef::trunc_s.with_log_width(6) + ConvertOpDef::trunc_s.with_log_width(6) } else { - ConvertOpDef::trunc_u.with_log_width(6) + ConvertOpDef::trunc_u.with_log_width(6) }; - builder - .add_dataflow_op(op, [flt]) - .unwrap() - .outputs_arr() + builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr() }; let sum_ty = sum_with_error(int64.clone()); let variants = (0..sum_ty.num_variants()) @@ -482,12 +486,26 @@ mod test { #[case(4294967295)] #[case(42)] #[case(18_000_000_000_000_000_000)] - fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: u64) { + fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) { add_extensions(&mut exec_ctx); let hugr = roundtrip_hugr(val, false); assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main")); } + #[rstest] + // Exact roundtrip conversion is defined on values up to 2**53 for f64. + #[case(0)] + #[case(3)] + #[case(255)] + #[case(4294967295)] + #[case(42)] + #[case(-9_000_000_000_000_000_000)] + fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) { + add_extensions(&mut exec_ctx); + let hugr = roundtrip_hugr(val as u64, true); + assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64); + } + // For unisgined ints larger than (1 << 54) - 1, f64s do not have enough // precision to exactly roundtrip the int. // The exact behaviour of the round-trip is is platform-dependent.