From c73f36ba5e18aa99da65957bd3a4b7babd9eda49 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 23 Oct 2024 15:43:07 +0100 Subject: [PATCH] fix: don't normalise half turns (#137) Closes #136 drive-by: cargo update (hugr 0.13.2) I have tested the test in https://github.com/CQCL-DEV/guppy-integration/pull/19 now passes and the example has expected behaviour with crz oracle. --- Cargo.lock | 63 ++++++++++--------- src/extension/rotation.rs | 40 +++--------- ...__rotation__test__emit_all_ops@llvm14.snap | 57 +++++++---------- ...test__emit_all_ops@pre-mem2reg@llvm14.snap | 63 ++++++++----------- 4 files changed, 92 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e73eef6..70d27db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,9 +77,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "approx" @@ -153,9 +153,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.1.30" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", @@ -644,9 +644,9 @@ dependencies = [ [[package]] name = "hugr" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1591e9f21f7c32c1b6f25bc85239d607fe86d39d8913dd5fa24c26c8bb2859" +checksum = "a4d3b355ab819143e7dd85bda162f1a496fbbf8485a4bb1ff55f5dd52b031fca" dependencies = [ "hugr-core", "hugr-passes", @@ -654,14 +654,15 @@ dependencies = [ [[package]] name = "hugr-cli" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "510dc98333646f9a9cecd9443055a83b9f09cdd72a9cb0a8fab2b740e304d478" +checksum = "3407f8549502a9b8905db5b1ec5dd70bfc05b072313a3f2381ec3414f450f80d" dependencies = [ "clap", "clap-verbosity-flag", "clio", - "hugr-core", + "derive_more", + "hugr", "serde", "serde_json", "thiserror", @@ -669,9 +670,9 @@ dependencies = [ [[package]] name = "hugr-core" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3c692210006389030fe0cab753c0982e62d0d45b0ee366b748d8c568593f639" +checksum = "a86a0796418a62b4cd83c3776021ec46671142fd46b5d355e2078db4c8d286c3" dependencies = [ "bitvec", "bumpalo", @@ -726,9 +727,9 @@ dependencies = [ [[package]] name = "hugr-passes" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "050a23a7f4226712bcefa3d115c102647ed9e95709b284eecc115233dc9e6bfc" +checksum = "f8a375f48c93f2839faad0a2ab6a58c4c88ad656d70398dbd6b45e6371362348" dependencies = [ "hugr-core", "itertools 0.13.0", @@ -881,9 +882,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "linked-hash-map" @@ -1085,9 +1086,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -1245,18 +1246,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", @@ -1265,9 +1266,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -1347,9 +1348,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", @@ -1377,18 +1378,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", @@ -1545,9 +1546,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "serde", ] diff --git a/src/extension/rotation.rs b/src/extension/rotation.rs index 1cc72ee..8031f0b 100644 --- a/src/extension/rotation.rs +++ b/src/extension/rotation.rs @@ -14,7 +14,7 @@ use lazy_static::lazy_static; use crate::{ custom::CodegenExtsBuilder, - emit::{emit_value, get_intrinsic, EmitFuncContext, EmitOpArgs}, + emit::{emit_value, EmitFuncContext, EmitOpArgs}, types::TypingSession, CodegenExtension, }; @@ -138,9 +138,7 @@ impl RotationCodegenExtension { op: RotationOp, ) -> Result<()> { let ts = context.typing_session(); - let module = context.get_current_module(); let builder = context.builder(); - let angle_ty = llvm_angle_type(&ts); match op { RotationOp::radd => { @@ -204,27 +202,7 @@ impl RotationCodegenExtension { .map_err(|_| anyhow!("RotationOp::tohalfturns expects one argument"))?; let half_turns = half_turns.into_float_value(); - // normalised_half_turns is in the interval 0..2 - let normalised_half_turns = { - // normalised_rads = (half_turns/2 - floor(half_turns/2)) * 2 - // note that floor(x) gives the largest integral value less - // than or equal to x so this deals with both positive and - // negative rads. - let turns = - builder.build_float_div(half_turns, angle_ty.const_float(2.0), "")?; - let floor_turns = { - let floor = get_intrinsic(module, "llvm.floor", [angle_ty.into()])?; - builder - .build_call(floor, &[turns.into()], "")? - .try_as_basic_value() - .left() - .ok_or(anyhow!("llvm.floor has no return value"))? - .into_float_value() - }; - let normalised_turns = builder.build_float_sub(turns, floor_turns, "")?; - builder.build_float_mul(normalised_turns, angle_ty.const_float(2.0), "")? - }; - args.outputs.finish(builder, [normalised_half_turns.into()]) + args.outputs.finish(builder, [half_turns.into()]) } op => bail!("Unsupported op: {op:?}"), } @@ -314,7 +292,7 @@ mod test { #[rstest] #[case(ConstRotation::new(1.0).unwrap(), ConstRotation::new(0.5).unwrap(), 1.5)] - #[case(ConstRotation::PI, ConstRotation::new(1.5).unwrap(), 0.5)] + #[case(ConstRotation::PI, ConstRotation::new(1.5).unwrap(), 2.5)] fn exec_aadd( mut exec_ctx: TestContext, #[case] angle1: ConstRotation, @@ -350,7 +328,7 @@ mod test { #[rstest] #[case(ConstRotation::PI, 1.0)] - #[case(ConstRotation::TAU, 0.0)] + #[case(ConstRotation::TAU, 2.0)] #[case(ConstRotation::PI_2, 0.5)] #[case(ConstRotation::PI_4, 0.25)] fn exec_to_halfturns( @@ -420,13 +398,13 @@ mod test { #[rstest] #[case(1.0, Some(1.0))] - #[case(-1.0, Some(1.0))] + #[case(-1.0, Some (-1.0))] #[case(0.5, Some(0.5))] - #[case(-0.5, Some(1.5))] + #[case(-0.5, Some (-0.5))] #[case(0.25, Some(0.25))] - #[case(-0.25, Some(1.75))] - #[case(13.5, Some(1.5))] - #[case(-13.5, Some(0.5))] + #[case(-0.25, Some (-0.25))] + #[case(13.5, Some(13.5))] + #[case(-13.5, Some (-13.5))] #[case(f64::NAN, None)] #[case(f64::INFINITY, None)] #[case(f64::NEG_INFINITY, None)] diff --git a/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@llvm14.snap b/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@llvm14.snap index fe9e031..95481fb 100644 --- a/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@llvm14.snap +++ b/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@llvm14.snap @@ -29,54 +29,45 @@ entry_block: ; preds = %alloca_block unreachable 9: ; preds = %entry_block - %10 = fdiv double %0, 2.000000e+00 - %11 = call double @llvm.floor.f64(double %10) - %12 = fsub double %10, %11 - %13 = fmul double %12, 2.000000e+00 - %14 = fcmp oeq double %13, 0x7FF0000000000000 - %15 = fcmp oeq double %13, 0xFFF0000000000000 - %16 = fcmp uno double %13, 0.000000e+00 - %17 = or i1 %14, %15 - %18 = or i1 %17, %16 - %19 = xor i1 %18, true - %20 = insertvalue { double } undef, double %13, 0 - %21 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %20, 2 - %22 = select i1 %19, { i32, {}, { double } } %21, { i32, {}, { double } } { i32 0, {} undef, { double } poison } - %23 = extractvalue { i32, {}, { double } } %22, 0 - switch i32 %23, label %24 [ - i32 1, label %26 + %10 = fcmp oeq double %0, 0x7FF0000000000000 + %11 = fcmp oeq double %0, 0xFFF0000000000000 + %12 = fcmp uno double %0, 0.000000e+00 + %13 = or i1 %10, %11 + %14 = or i1 %13, %12 + %15 = xor i1 %14, true + %16 = insertvalue { double } undef, double %0, 0 + %17 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %16, 2 + %18 = select i1 %15, { i32, {}, { double } } %17, { i32, {}, { double } } { i32 0, {} undef, { double } poison } + %19 = extractvalue { i32, {}, { double } } %18, 0 + switch i32 %19, label %20 [ + i32 1, label %22 ] -24: ; preds = %9 - %25 = extractvalue { i32, {}, { double } } %22, 1 +20: ; preds = %9 + %21 = extractvalue { i32, {}, { double } } %18, 1 br label %cond_7_case_0 -26: ; preds = %9 - %27 = extractvalue { i32, {}, { double } } %22, 2 - %28 = extractvalue { double } %27, 0 +22: ; preds = %9 + %23 = extractvalue { i32, {}, { double } } %18, 2 + %24 = extractvalue { double } %23, 0 br label %cond_7_case_1 -cond_7_case_0: ; preds = %24 - %29 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0 - %30 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1 - %31 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %29, i8* %30) +cond_7_case_0: ; preds = %20 + %25 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0 + %26 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1 + %27 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %25, i8* %26) call void @abort() br label %cond_exit_7 -cond_7_case_1: ; preds = %26 +cond_7_case_1: ; preds = %22 br label %cond_exit_7 cond_exit_7: ; preds = %cond_7_case_1, %cond_7_case_0 - %"0.0" = phi double [ 0.000000e+00, %cond_7_case_0 ], [ %28, %cond_7_case_1 ] - %32 = fadd double %0, %"0.0" + %"0.0" = phi double [ 0.000000e+00, %cond_7_case_0 ], [ %24, %cond_7_case_1 ] + %28 = fadd double %0, %"0.0" ret void } declare i32 @printf(i8*, ...) declare void @abort() - -; Function Attrs: nofree nosync nounwind readnone speculatable willreturn -declare double @llvm.floor.f64(double) #0 - -attributes #0 = { nofree nosync nounwind readnone speculatable willreturn } diff --git a/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@pre-mem2reg@llvm14.snap b/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@pre-mem2reg@llvm14.snap index c07ff82..02ef0aa 100644 --- a/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@pre-mem2reg@llvm14.snap +++ b/src/extension/snapshots/hugr_llvm__extension__rotation__test__emit_all_ops@pre-mem2reg@llvm14.snap @@ -44,51 +44,47 @@ entry_block: ; preds = %alloca_block 9: ; preds = %entry_block store double %"2_01", double* %"4_0", align 8 %"4_02" = load double, double* %"4_0", align 8 - %10 = fdiv double %"4_02", 2.000000e+00 - %11 = call double @llvm.floor.f64(double %10) - %12 = fsub double %10, %11 - %13 = fmul double %12, 2.000000e+00 - store double %13, double* %"5_0", align 8 + store double %"4_02", double* %"5_0", align 8 %"5_03" = load double, double* %"5_0", align 8 - %14 = fcmp oeq double %"5_03", 0x7FF0000000000000 - %15 = fcmp oeq double %"5_03", 0xFFF0000000000000 - %16 = fcmp uno double %"5_03", 0.000000e+00 - %17 = or i1 %14, %15 - %18 = or i1 %17, %16 - %19 = xor i1 %18, true - %20 = insertvalue { double } undef, double %"5_03", 0 - %21 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %20, 2 - %22 = select i1 %19, { i32, {}, { double } } %21, { i32, {}, { double } } { i32 0, {} undef, { double } poison } - store { i32, {}, { double } } %22, { i32, {}, { double } }* %"6_0", align 8 + %10 = fcmp oeq double %"5_03", 0x7FF0000000000000 + %11 = fcmp oeq double %"5_03", 0xFFF0000000000000 + %12 = fcmp uno double %"5_03", 0.000000e+00 + %13 = or i1 %10, %11 + %14 = or i1 %13, %12 + %15 = xor i1 %14, true + %16 = insertvalue { double } undef, double %"5_03", 0 + %17 = insertvalue { i32, {}, { double } } { i32 1, {} poison, { double } poison }, { double } %16, 2 + %18 = select i1 %15, { i32, {}, { double } } %17, { i32, {}, { double } } { i32 0, {} undef, { double } poison } + store { i32, {}, { double } } %18, { i32, {}, { double } }* %"6_0", align 8 %"6_04" = load { i32, {}, { double } }, { i32, {}, { double } }* %"6_0", align 8 - %23 = extractvalue { i32, {}, { double } } %"6_04", 0 - switch i32 %23, label %24 [ - i32 1, label %26 + %19 = extractvalue { i32, {}, { double } } %"6_04", 0 + switch i32 %19, label %20 [ + i32 1, label %22 ] -24: ; preds = %9 - %25 = extractvalue { i32, {}, { double } } %"6_04", 1 +20: ; preds = %9 + %21 = extractvalue { i32, {}, { double } } %"6_04", 1 br label %cond_7_case_0 -26: ; preds = %9 - %27 = extractvalue { i32, {}, { double } } %"6_04", 2 - %28 = extractvalue { double } %27, 0 - store double %28, double* %"08", align 8 +22: ; preds = %9 + %23 = extractvalue { i32, {}, { double } } %"6_04", 2 + %24 = extractvalue { double } %23, 0 + store double %24, double* %"08", align 8 br label %cond_7_case_1 -cond_7_case_0: ; preds = %24 +cond_7_case_0: ; preds = %20 store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, { i32, i8* }* %"12_0", align 8 %"12_06" = load { i32, i8* }, { i32, i8* }* %"12_0", align 8 - %29 = extractvalue { i32, i8* } %"12_06", 0 - %30 = extractvalue { i32, i8* } %"12_06", 1 - %31 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %29, i8* %30) + %25 = extractvalue { i32, i8* } %"12_06", 0 + %26 = extractvalue { i32, i8* } %"12_06", 1 + %27 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %25, i8* %26) call void @abort() store double 0.000000e+00, double* %"13_0", align 8 %"13_07" = load double, double* %"13_0", align 8 store double %"13_07", double* %"0", align 8 br label %cond_exit_7 -cond_7_case_1: ; preds = %26 +cond_7_case_1: ; preds = %22 %"09" = load double, double* %"08", align 8 store double %"09", double* %"15_0", align 8 %"15_010" = load double, double* %"15_0", align 8 @@ -100,16 +96,11 @@ cond_exit_7: ; preds = %cond_7_case_1, %con store double %"05", double* %"7_0", align 8 %"4_011" = load double, double* %"4_0", align 8 %"7_012" = load double, double* %"7_0", align 8 - %32 = fadd double %"4_011", %"7_012" - store double %32, double* %"17_0", align 8 + %28 = fadd double %"4_011", %"7_012" + store double %28, double* %"17_0", align 8 ret void } declare i32 @printf(i8*, ...) declare void @abort() - -; Function Attrs: nofree nosync nounwind readnone speculatable willreturn -declare double @llvm.floor.f64(double) #0 - -attributes #0 = { nofree nosync nounwind readnone speculatable willreturn }