Skip to content

Commit

Permalink
Support exact-integer-sqrt (#174)
Browse files Browse the repository at this point in the history
* Support exact-integer-sqrt

* Simplify with steel_derive::function

* Add error cases.
  • Loading branch information
wmedrano authored Mar 2, 2024
1 parent 1ab7b1f commit 9856a22
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 4 deletions.
114 changes: 110 additions & 4 deletions crates/steel-core/src/primitives/numbers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ fn log(args: &[SteelVal]) -> Result<SteelVal> {
match (first, &base) {
(SteelVal::IntV(1), _) => Ok(SteelVal::IntV(0)),
(SteelVal::IntV(_) | SteelVal::NumV(_), SteelVal::IntV(1)) => {
stop!(Generic => "log: divide by zero with args: {} and {}", first, base);
steelerr!(Generic => "log: divide by zero with args: {} and {}", first, base)
}
(SteelVal::IntV(arg), SteelVal::NumV(n)) => Ok(SteelVal::NumV((*arg as f64).log(*n))),
(SteelVal::IntV(arg), SteelVal::IntV(base)) => Ok(SteelVal::IntV(arg.ilog(*base) as isize)),
Expand All @@ -501,6 +501,42 @@ fn log(args: &[SteelVal]) -> Result<SteelVal> {
}
}

/// Returns an integer that is closest (but not greater than) the square root of an integer and the
/// remainder.
///
/// ```scheme
/// (exact-integer-sqrt x) => '(root rem)
/// (equal? x (+ (square root) rem)) => true
/// ```
#[steel_derive::function(name = "exact-integer-sqrt", constant = true)]
fn exact_integer_sqrt(number: &SteelVal) -> Result<SteelVal> {
match number {
SteelVal::IntV(x) if *x >= 0 => {
let (ans, rem) = exact_integer_impl(x);
(ans.into_steelval()?, rem.into_steelval()?).into_steelval()
}
SteelVal::BigNum(x) if !x.is_negative() => {
let (ans, rem) = exact_integer_impl(x.as_ref());
(ans.into_steelval()?, rem.into_steelval()?).into_steelval()
}
_ => {
steelerr!(TypeMismatch => "exact-integer-sqrt expects a non-negative integer but found {number}")
}
}
}

fn exact_integer_impl<'a, N>(target: &'a N) -> (N, N)
where
N: num::integer::Roots + Clone,
&'a N: std::ops::Mul<&'a N, Output = N>,
N: std::ops::Sub<N, Output = N>,
{
let x = target.sqrt();
let x_sq = x.clone() * x.clone();
let rem = target.clone() - x_sq;
(x, rem)
}

fn ensure_args_are_numbers(op: &str, args: &[SteelVal]) -> Result<()> {
for arg in args {
if !numberp(arg) {
Expand Down Expand Up @@ -780,7 +816,7 @@ impl NumOperations {
Ok(SteelVal::IntV(n >> -m))
}
}
_ => stop!(TypeMismatch => "arithmetic-shift expected 2 integers"),
_ => steelerr!(TypeMismatch => "arithmetic-shift expected 2 integers"),
}
})
}
Expand All @@ -796,7 +832,7 @@ impl NumOperations {
SteelVal::BigNum(n) => Ok(SteelVal::BoolV(n.is_even())),
SteelVal::NumV(n) if n.fract() == 0.0 => (*n as i64).is_even().into_steelval(),
_ => {
stop!(TypeMismatch => format!("even? requires an integer, found: {:?}", &args[0]))
steelerr!(TypeMismatch => format!("even? requires an integer, found: {:?}", &args[0]))
}
}
})
Expand All @@ -813,7 +849,7 @@ impl NumOperations {
SteelVal::BigNum(n) => Ok(SteelVal::BoolV(n.is_odd())),
SteelVal::NumV(n) if n.fract() == 0.0 => (*n as i64).is_odd().into_steelval(),
_ => {
stop!(TypeMismatch => format!("odd? requires an integer, found: {:?}", &args[0]))
steelerr!(TypeMismatch => format!("odd? requires an integer, found: {:?}", &args[0]))
}
}
})
Expand Down Expand Up @@ -992,4 +1028,74 @@ mod num_op_tests {
let expected = IntV(8);
assert_eq!(got, expected);
}

#[test]
fn test_exact_integer_sqrt() {
assert_eq!(
exact_integer_sqrt(&0.into()),
(0.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&1.into()),
(1.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&2.into()),
(1.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&2.into()),
(1.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&3.into()),
(1.into_steelval().unwrap(), 2.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&4.into()),
(2.into_steelval().unwrap(), 0.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&5.into()),
(2.into_steelval().unwrap(), 1.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&6.into()),
(2.into_steelval().unwrap(), 2.into_steelval().unwrap()).into_steelval()
);
assert_eq!(
exact_integer_sqrt(&7.into()),
(2.into_steelval().unwrap(), 3.into_steelval().unwrap()).into_steelval()
);
}

#[test]
fn test_exact_integer_sqrt_fails_on_negative_or_noninteger() {
assert!(exact_integer_sqrt(&(-7).into()).is_err());
assert!(exact_integer_sqrt(&(-7).into()).is_err());
assert!(exact_integer_sqrt(&Rational32::new(-1, 2).into_steelval().unwrap()).is_err());
assert!(exact_integer_sqrt(
&BigInt::from_str("-10000000000000000000000000000000000001")
.unwrap()
.into_steelval()
.unwrap()
)
.is_err());
assert!(exact_integer_sqrt(
&num::BigRational::new(
BigInt::from_str("-10000000000000000000000000000000000001").unwrap(),
BigInt::from_str("2").unwrap()
)
.into_steelval()
.unwrap()
)
.is_err());
assert!(exact_integer_sqrt(&(1.0).into()).is_err());
assert!(exact_integer_sqrt(
&SteelComplex::new(1.into(), 1.into())
.into_steelval()
.unwrap()
)
.is_err());
}
}
1 change: 1 addition & 0 deletions crates/steel-core/src/steel_vm/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ fn number_module() -> BuiltInModule {
.register_native_fn_definition(numbers::DENOMINATOR_DEFINITION)
.register_native_fn_definition(numbers::EXACTP_DEFINITION)
.register_native_fn_definition(numbers::EXACT_TO_INEXACT_DEFINITION)
.register_native_fn_definition(numbers::EXACT_INTEGER_SQRT_DEFINITION)
.register_native_fn_definition(numbers::EXPT_DEFINITION)
.register_native_fn_definition(numbers::EXP_DEFINITION)
.register_native_fn_definition(numbers::FINITEP_DEFINITION)
Expand Down
15 changes: 15 additions & 0 deletions crates/steel-core/src/tests/success/numbers.scm
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,18 @@
(log 100 10.0))
(assert-equal! 1.0
(log (exp 1)))

(assert-equal! '(0 0)
(exact-integer-sqrt 0))
(assert-equal! '(1 0)
(exact-integer-sqrt 1))
(assert-equal! '(1 1)
(exact-integer-sqrt 2))
(assert-equal! '(1 2)
(exact-integer-sqrt 3))
(assert-equal! '(2 0)
(exact-integer-sqrt 4))
(assert-equal! '(2 1)
(exact-integer-sqrt 5))
(assert-equal! '(10000000000000000000000 4)
(exact-integer-sqrt 100000000000000000000000000000000000000000004))

0 comments on commit 9856a22

Please sign in to comment.