diff --git a/src/harness/check.rs b/src/harness/check.rs index 87e771e..e5be8f1 100644 --- a/src/harness/check.rs +++ b/src/harness/check.rs @@ -1,5 +1,6 @@ use console::Style; use harness::run::{FuncBuffer, ValBuffer}; +use kdl_script::types::PrimitiveTy; use kdl_script::types::Ty; use tracing::{error, info}; @@ -150,6 +151,17 @@ impl TestHarness { let callee = enum_variant_name(enum_ty, callee_tag); return Err(tag_error(types, &expected_val, expected, caller, callee)); } + } else if let Ty::Primitive(PrimitiveTy::Bool) = types.realize_ty(expected_val.ty) { + let expected_tag = expected_val.generate_idx(2); + let caller_tag = load_tag(caller_val); + let callee_tag = load_tag(callee_val); + + if caller_tag != expected_tag || callee_tag != expected_tag { + let expected = bool_variant_name(expected_tag, expected_tag); + let caller = bool_variant_name(expected_tag, caller_tag); + let callee = bool_variant_name(expected_tag, callee_tag); + return Err(tag_error(types, &expected_val, expected, caller, callee)); + } } else if caller_val.bytes != callee_val.bytes { // General case, just get a pile of bytes to span both values let func = expected_val.func(); @@ -201,6 +213,25 @@ fn enum_variant_name(enum_ty: &kdl_script::types::EnumTy, tag: usize) -> String format!("{enum_name}::{variant_name}") } +fn bool_variant_name(expected_tag: usize, tag: usize) -> String { + // Because we're using the tag variant machinery, this code is a bit weird, + // because we essentially get passed Option for `tag`, where we get + // None when the wrong path is taken. + // + // So to figure out what variant a bool is supposed to have, we work out + // what variant the expected_tag has, and then either say "the same or opposite" + let bools = ["false", "true"]; + let expected_bool = bools[expected_tag]; + let unexpected_bool = bools[1 - expected_tag]; + + let res = if tag == expected_tag { + expected_bool + } else { + unexpected_bool + }; + res.to_owned() +} + fn tag_error( types: &kdl_script::TypedProgram, expected_val: &ValueRef, diff --git a/src/harness/generate.rs b/src/harness/generate.rs index 622a944..667ff13 100644 --- a/src/harness/generate.rs +++ b/src/harness/generate.rs @@ -43,11 +43,11 @@ impl TestHarness { .or_insert_with(|| Arc::new(OnceCell::new())) .clone(); // Either acquire the cached result, or make it + info!("generating {}", &src_path); let _ = once .get_or_try_init(|| { let toolchain = self.toolchain_by_test_key(key, call_side); let options = key.options.clone(); - info!("generating {}", &src_path); generate_src( &src_path, toolchain, diff --git a/src/harness/vals.rs b/src/harness/vals.rs index 1bf7bfe..800690e 100644 --- a/src/harness/vals.rs +++ b/src/harness/vals.rs @@ -454,6 +454,10 @@ impl ValueGenerator { }; rng.gen_range(0..len) } + pub fn generate_bool(&self) -> bool { + let idx = self.generate_idx(2); + idx == 1 + } pub fn generate_u8(&self) -> u8 { let mut buf = [0; 1]; self.fill_bytes(&mut buf); diff --git a/src/toolchains/c/init.rs b/src/toolchains/c/init.rs index c52c2db..1c3f629 100644 --- a/src/toolchains/c/init.rs +++ b/src/toolchains/c/init.rs @@ -58,7 +58,7 @@ impl CcToolchain { write!(f, "{val}")? } } - PrimitiveTy::Bool => write!(f, "true")?, + PrimitiveTy::Bool => write!(f, "{}", val.generate_bool())?, PrimitiveTy::Ptr => { if true { write!(f, "(void*){:#X}ull", val.generate_u64())? diff --git a/src/toolchains/c/write.rs b/src/toolchains/c/write.rs index 1bf4610..f8cf24a 100644 --- a/src/toolchains/c/write.rs +++ b/src/toolchains/c/write.rs @@ -65,6 +65,30 @@ impl CcToolchain { vals: &mut ArgValuesIter, ) -> Result<(), GenerateError> { match state.types.realize_ty(var_ty) { + Ty::Primitive(PrimitiveTy::Bool) => { + // bool is basically an enum with variants "false" (0) and "true" (1) + let tag_generator = vals.next_val(); + let cond = tag_generator.generate_bool(); + if tag_generator.should_write_val(&state.options) { + writeln!(f, "if ({from}) {{")?; + f.add_indent(1); + if cond { + self.write_tag_field(f, state, to, from, 1, &tag_generator)?; + } else { + self.write_error_tag_field(f, state, to, &tag_generator)?; + } + f.sub_indent(1); + writeln!(f, "}} else {{")?; + f.add_indent(1); + if !cond { + self.write_tag_field(f, state, to, from, 0, &tag_generator)?; + } else { + self.write_error_tag_field(f, state, to, &tag_generator)?; + } + f.sub_indent(1); + writeln!(f, "}}")?; + } + } Ty::Primitive(_) => { // Hey an actual leaf, report it (and burn a value) let val = vals.next_val(); diff --git a/src/toolchains/rust/init.rs b/src/toolchains/rust/init.rs index f8a078c..b13771e 100644 --- a/src/toolchains/rust/init.rs +++ b/src/toolchains/rust/init.rs @@ -28,7 +28,7 @@ impl RustcToolchain { PrimitiveTy::F32 => write!(f, "f32::from_bits({})", val.generate_u32())?, PrimitiveTy::F64 => write!(f, "f64::from_bits({})", val.generate_u64())?, - PrimitiveTy::Bool => write!(f, "true")?, + PrimitiveTy::Bool => write!(f, "{}", val.generate_bool())?, PrimitiveTy::Ptr => { if true { write!(f, "{:#X}u64 as *mut ()", val.generate_u64())? diff --git a/src/toolchains/rust/write.rs b/src/toolchains/rust/write.rs index 4ce837c..e2f3203 100644 --- a/src/toolchains/rust/write.rs +++ b/src/toolchains/rust/write.rs @@ -85,6 +85,30 @@ impl RustcToolchain { vals: &mut ArgValuesIter, ) -> Result<(), GenerateError> { match state.types.realize_ty(var_ty) { + Ty::Primitive(PrimitiveTy::Bool) => { + // bool is basically an enum with variants "false" (0) and "true" (1) + let tag_generator = vals.next_val(); + let cond = tag_generator.generate_bool(); + if tag_generator.should_write_val(&state.options) { + writeln!(f, "if {from} {{")?; + f.add_indent(1); + if cond { + self.write_tag_field(f, state, to, from, 1, &tag_generator)?; + } else { + self.write_error_tag_field(f, state, to, &tag_generator)?; + } + f.sub_indent(1); + writeln!(f, "}} else {{")?; + f.add_indent(1); + if !cond { + self.write_tag_field(f, state, to, from, 0, &tag_generator)?; + } else { + self.write_error_tag_field(f, state, to, &tag_generator)?; + } + f.sub_indent(1); + writeln!(f, "}}")?; + } + } Ty::Primitive(_) => { // Hey an actual leaf, report it (and burn a value) let val = vals.next_val();