diff --git a/crates/steel-core/src/primitives/numbers.rs b/crates/steel-core/src/primitives/numbers.rs index 26a80a54f..c5b3d1ee1 100644 --- a/crates/steel-core/src/primitives/numbers.rs +++ b/crates/steel-core/src/primitives/numbers.rs @@ -488,7 +488,7 @@ fn log(args: &[SteelVal]) -> Result { match (first, &base) { (SteelVal::IntV(1), _) => Ok(SteelVal::IntV(0)), (SteelVal::IntV(_) | SteelVal::NumV(_), SteelVal::IntV(1)) => { - stop!(Generic => "log: divide by zero with args: {} and {}", first, base); + steelerr!(Generic => "log: divide by zero with args: {} and {}", first, base) } (SteelVal::IntV(arg), SteelVal::NumV(n)) => Ok(SteelVal::NumV((*arg as f64).log(*n))), (SteelVal::IntV(arg), SteelVal::IntV(base)) => Ok(SteelVal::IntV(arg.ilog(*base) as isize)), @@ -501,6 +501,42 @@ fn log(args: &[SteelVal]) -> Result { } } +/// Returns an integer that is closest (but not greater than) the square root of an integer and the +/// remainder. +/// +/// ```scheme +/// (exact-integer-sqrt x) => '(root rem) +/// (equal? x (+ (square root) rem)) => true +/// ``` +#[steel_derive::function(name = "exact-integer-sqrt", constant = true)] +fn exact_integer_sqrt(number: &SteelVal) -> Result { + match number { + SteelVal::IntV(x) if *x >= 0 => { + let (ans, rem) = exact_integer_impl(x); + (ans.into_steelval()?, rem.into_steelval()?).into_steelval() + } + SteelVal::BigNum(x) if !x.is_negative() => { + let (ans, rem) = exact_integer_impl(x.as_ref()); + (ans.into_steelval()?, rem.into_steelval()?).into_steelval() + } + _ => { + steelerr!(TypeMismatch => "exact-integer-sqrt expects a non-negative integer but found {number}") + } + } +} + +fn exact_integer_impl<'a, N>(target: &'a N) -> (N, N) +where + N: num::integer::Roots + Clone, + &'a N: std::ops::Mul<&'a N, Output = N>, + N: std::ops::Sub, +{ + let x = target.sqrt(); + let x_sq = x.clone() * x.clone(); + let rem = target.clone() - x_sq; + (x, rem) +} + fn ensure_args_are_numbers(op: &str, args: &[SteelVal]) -> Result<()> { for arg in args { if !numberp(arg) { @@ -780,7 +816,7 @@ impl NumOperations { Ok(SteelVal::IntV(n >> -m)) } } - _ => stop!(TypeMismatch => "arithmetic-shift expected 2 integers"), + _ => steelerr!(TypeMismatch => "arithmetic-shift expected 2 integers"), } }) } @@ -796,7 +832,7 @@ impl NumOperations { SteelVal::BigNum(n) => Ok(SteelVal::BoolV(n.is_even())), SteelVal::NumV(n) if n.fract() == 0.0 => (*n as i64).is_even().into_steelval(), _ => { - stop!(TypeMismatch => format!("even? requires an integer, found: {:?}", &args[0])) + steelerr!(TypeMismatch => format!("even? requires an integer, found: {:?}", &args[0])) } } }) @@ -813,7 +849,7 @@ impl NumOperations { SteelVal::BigNum(n) => Ok(SteelVal::BoolV(n.is_odd())), SteelVal::NumV(n) if n.fract() == 0.0 => (*n as i64).is_odd().into_steelval(), _ => { - stop!(TypeMismatch => format!("odd? requires an integer, found: {:?}", &args[0])) + steelerr!(TypeMismatch => format!("odd? requires an integer, found: {:?}", &args[0])) } } }) @@ -992,4 +1028,74 @@ mod num_op_tests { let expected = IntV(8); assert_eq!(got, expected); } + + #[test] + fn test_exact_integer_sqrt() { + assert_eq!( + exact_integer_sqrt(&0.into()), + (0.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&1.into()), + (1.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&2.into()), + (1.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&2.into()), + (1.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&3.into()), + (1.into_steelval().unwrap(), 2.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&4.into()), + (2.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&5.into()), + (2.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&6.into()), + (2.into_steelval().unwrap(), 2.into_steelval().unwrap()).into_steelval() + ); + assert_eq!( + exact_integer_sqrt(&7.into()), + (2.into_steelval().unwrap(), 3.into_steelval().unwrap()).into_steelval() + ); + } + + #[test] + fn test_exact_integer_sqrt_fails_on_negative_or_noninteger() { + assert!(exact_integer_sqrt(&(-7).into()).is_err()); + assert!(exact_integer_sqrt(&(-7).into()).is_err()); + assert!(exact_integer_sqrt(&Rational32::new(-1, 2).into_steelval().unwrap()).is_err()); + assert!(exact_integer_sqrt( + &BigInt::from_str("-10000000000000000000000000000000000001") + .unwrap() + .into_steelval() + .unwrap() + ) + .is_err()); + assert!(exact_integer_sqrt( + &num::BigRational::new( + BigInt::from_str("-10000000000000000000000000000000000001").unwrap(), + BigInt::from_str("2").unwrap() + ) + .into_steelval() + .unwrap() + ) + .is_err()); + assert!(exact_integer_sqrt(&(1.0).into()).is_err()); + assert!(exact_integer_sqrt( + &SteelComplex::new(1.into(), 1.into()) + .into_steelval() + .unwrap() + ) + .is_err()); + } } diff --git a/crates/steel-core/src/steel_vm/primitives.rs b/crates/steel-core/src/steel_vm/primitives.rs index b8949af76..fd1370923 100644 --- a/crates/steel-core/src/steel_vm/primitives.rs +++ b/crates/steel-core/src/steel_vm/primitives.rs @@ -842,6 +842,7 @@ fn number_module() -> BuiltInModule { .register_native_fn_definition(numbers::DENOMINATOR_DEFINITION) .register_native_fn_definition(numbers::EXACTP_DEFINITION) .register_native_fn_definition(numbers::EXACT_TO_INEXACT_DEFINITION) + .register_native_fn_definition(numbers::EXACT_INTEGER_SQRT_DEFINITION) .register_native_fn_definition(numbers::EXPT_DEFINITION) .register_native_fn_definition(numbers::EXP_DEFINITION) .register_native_fn_definition(numbers::FINITEP_DEFINITION) diff --git a/crates/steel-core/src/tests/success/numbers.scm b/crates/steel-core/src/tests/success/numbers.scm index 874a405fd..90e8b8d88 100644 --- a/crates/steel-core/src/tests/success/numbers.scm +++ b/crates/steel-core/src/tests/success/numbers.scm @@ -267,3 +267,18 @@ (log 100 10.0)) (assert-equal! 1.0 (log (exp 1))) + +(assert-equal! '(0 0) + (exact-integer-sqrt 0)) +(assert-equal! '(1 0) + (exact-integer-sqrt 1)) +(assert-equal! '(1 1) + (exact-integer-sqrt 2)) +(assert-equal! '(1 2) + (exact-integer-sqrt 3)) +(assert-equal! '(2 0) + (exact-integer-sqrt 4)) +(assert-equal! '(2 1) + (exact-integer-sqrt 5)) +(assert-equal! '(10000000000000000000000 4) + (exact-integer-sqrt 100000000000000000000000000000000000000000004))