diff --git a/crates/steel-core/src/parser/ast.rs b/crates/steel-core/src/parser/ast.rs index d775250a2..6965c78d2 100644 --- a/crates/steel-core/src/parser/ast.rs +++ b/crates/steel-core/src/parser/ast.rs @@ -202,7 +202,7 @@ impl TryFrom<&SteelVal> for ExprKind { BigNum(x) => Ok(ExprKind::Atom(Atom::new(SyntaxObject::default( IntegerLiteral(MaybeBigInt::Big(x.unwrap())), )))), - + Complex(_) => unimplemented!("Complex numbers not fully supported yet. See https://github.com/mattwparas/steel/issues/62 for current details."), VectorV(lst) => { let items: std::result::Result, &'static str> = lst.iter().map(|x| inner_try_from(x, depth + 1)).collect(); diff --git a/crates/steel-core/src/primitives/nums.rs b/crates/steel-core/src/primitives/nums.rs index 8071ee6d0..ce6ccc369 100644 --- a/crates/steel-core/src/primitives/nums.rs +++ b/crates/steel-core/src/primitives/nums.rs @@ -1,5 +1,5 @@ -use crate::rvals::{IntoSteelVal, Result, SteelVal}; -use crate::steel_vm::primitives::numberp; +use crate::rvals::{IntoSteelVal, Result, SteelComplex, SteelVal}; +use crate::steel_vm::primitives::{numberp, realp}; use crate::stop; use num::{BigInt, BigRational, CheckedAdd, CheckedMul, Integer, Rational32, ToPrimitive}; use std::ops::Neg; @@ -17,7 +17,7 @@ fn ensure_args_are_numbers(op: &str, args: &[SteelVal]) -> Result<()> { /// /// # Precondition /// - `x` and `y` must be valid numerical types. -fn multiply_unchecked(x: &SteelVal, y: &SteelVal) -> Result { +fn multiply_two(x: &SteelVal, y: &SteelVal) -> Result { match (x, y) { (SteelVal::NumV(x), SteelVal::NumV(y)) => (x * y).into_steelval(), (SteelVal::NumV(x), SteelVal::IntV(y)) | (SteelVal::IntV(y), SteelVal::NumV(x)) => { @@ -91,6 +91,12 @@ fn multiply_unchecked(x: &SteelVal, y: &SteelVal) -> Result { (x.as_ref() * y.as_ref()).into_steelval() } (SteelVal::BigNum(x), SteelVal::BigNum(y)) => (x.as_ref() * y.as_ref()).into_steelval(), + // Complex numbers. + (SteelVal::Complex(x), SteelVal::Complex(y)) => multiply_complex(x, y), + (SteelVal::Complex(x), y) | (y, SteelVal::Complex(x)) => { + let y = SteelComplex::new(y.clone(), SteelVal::IntV(0)); + multiply_complex(x, &y) + } _ => unreachable!(), } } @@ -101,13 +107,13 @@ fn multiply_primitive_impl(args: &[SteelVal]) -> Result { match args { [] => 1.into_steelval(), [x] => x.clone().into_steelval(), - [x, y] => multiply_unchecked(x, y).into_steelval(), + [x, y] => multiply_two(x, y).into_steelval(), [x, y, zs @ ..] => { - let mut res = multiply_unchecked(x, y)?; + let mut res = multiply_two(x, y)?; for z in zs { // TODO: This use case could be optimized to reuse state instead of creating a new // object each time. - res = multiply_unchecked(&res, &z)?; + res = multiply_two(&res, &z)?; } res.into_steelval() } @@ -134,8 +140,8 @@ pub fn divide_primitive(args: &[SteelVal]) -> Result { Err(_) => BigRational::new(BigInt::from(1), BigInt::from(*n)).into_steelval(), }, SteelVal::NumV(n) => n.recip().into_steelval(), - SteelVal::Rational(f) => f.recip().into_steelval(), - SteelVal::BigRational(f) => f.recip().into_steelval(), + SteelVal::Rational(r) => r.recip().into_steelval(), + SteelVal::BigRational(r) => r.recip().into_steelval(), SteelVal::BigNum(n) => BigRational::new(1.into(), n.as_ref().clone()).into_steelval(), unexpected => { stop!(TypeMismatch => "/ expects a number, but found: {:?}", unexpected) @@ -146,18 +152,20 @@ pub fn divide_primitive(args: &[SteelVal]) -> Result { [] => stop!(ArityMismatch => "/ requires at least one argument"), [x] => recip(x), // TODO: Provide custom implementation to optimize by joining the multiply and recip calls. - [x, y] => multiply_unchecked(x, &recip(y)?), + [x, y] => multiply_two(x, &recip(y)?), [x, ys @ ..] => { let d = multiply_primitive_impl(ys)?; - multiply_unchecked(&x, &recip(&d)?) + multiply_two(&x, &recip(&d)?) } } } -#[steel_derive::native(name = "-", constant = true, arity = "AtLeast(1)")] -pub fn subtract_primitive(args: &[SteelVal]) -> Result { - ensure_args_are_numbers("-", args)?; - let negate = |x: &SteelVal| match x { +/// Negate a number. +/// +/// # Precondition +/// `value` must be a number. +fn negate(value: &SteelVal) -> Result { + match value { SteelVal::NumV(x) => (-x).into_steelval(), SteelVal::IntV(x) => match x.checked_neg() { Some(res) => res.into_steelval(), @@ -171,18 +179,28 @@ pub fn subtract_primitive(args: &[SteelVal]) -> Result { }, SteelVal::BigRational(x) => x.as_ref().neg().into_steelval(), SteelVal::BigNum(x) => x.as_ref().clone().neg().into_steelval(), + SteelVal::Complex(x) => negate_complex(x), _ => unreachable!(), - }; + } +} + +#[steel_derive::native(name = "-", constant = true, arity = "AtLeast(1)")] +pub fn subtract_primitive(args: &[SteelVal]) -> Result { + ensure_args_are_numbers("-", args)?; match args { [] => stop!(TypeMismatch => "- requires at least one argument"), [x] => negate(x), [x, ys @ ..] => { let y = negate(&add_primitive(ys)?)?; - add_primitive(&[x.clone(), y]) + add_two(x, &y) } } } +/// Adds two numbers. +/// +/// # Precondition +/// x and y must be valid numbers. pub fn add_two(x: &SteelVal, y: &SteelVal) -> Result { match (x, y) { // Simple integer case. Probably very common. @@ -259,6 +277,12 @@ pub fn add_two(x: &SteelVal, y: &SteelVal) -> Result { res += *y; res.into_steelval() } + // Complex numbers + (SteelVal::Complex(x), SteelVal::Complex(y)) => add_complex(x, y), + (SteelVal::Complex(x), y) | (y, SteelVal::Complex(x)) => { + debug_assert!(realp(y)); + add_complex(x, &SteelComplex::new(y.clone(), SteelVal::IntV(0))) + } _ => unreachable!(), } } @@ -280,17 +304,48 @@ pub fn add_primitive(args: &[SteelVal]) -> Result { } } +#[cold] +fn multiply_complex(x: &SteelComplex, y: &SteelComplex) -> Result { + // TODO: Optimize the implementation if needed. + let real = add_two( + &multiply_two(&x.re, &y.re)?, + &negate(&multiply_two(&x.im, &y.im)?)?, + )?; + let im = add_two(&multiply_two(&x.re, &y.im)?, &multiply_two(&x.im, &y.re)?)?; + SteelComplex::new(real, im).into_steelval() +} + +#[cold] +fn negate_complex(x: &SteelComplex) -> Result { + // TODO: Optimize the implementation if needed. + SteelComplex::new(negate(&x.re)?, negate(&x.im)?).into_steelval() +} + +#[cold] +fn add_complex(x: &SteelComplex, y: &SteelComplex) -> Result { + // TODO: Optimize the implementation if needed. + SteelComplex::new(add_two(&x.re, &y.re)?, add_two(&x.im, &y.im)?).into_steelval() +} + #[steel_derive::function(name = "exact?", constant = true)] pub fn exactp(value: &SteelVal) -> bool { - matches!( - value, - SteelVal::IntV(_) | SteelVal::BigNum(_) | SteelVal::Rational(_) | SteelVal::BigRational(_) - ) + match value { + SteelVal::IntV(_) + | SteelVal::BigNum(_) + | SteelVal::Rational(_) + | SteelVal::BigRational(_) => true, + SteelVal::Complex(x) => exactp(&x.re) && exactp(&x.im), + _ => false, + } } #[steel_derive::function(name = "inexact?", constant = true)] pub fn inexactp(value: &SteelVal) -> bool { - matches!(value, SteelVal::NumV(_)) + match value { + SteelVal::NumV(_) => true, + SteelVal::Complex(x) => inexactp(&x.re) || inexactp(&x.im), + _ => false, + } } pub struct NumOperations {} diff --git a/crates/steel-core/src/rvals.rs b/crates/steel-core/src/rvals.rs index 827965c89..e4514cd82 100644 --- a/crates/steel-core/src/rvals.rs +++ b/crates/steel-core/src/rvals.rs @@ -9,7 +9,10 @@ use crate::{ tokens::TokenType, }, rerrs::{ErrorKind, SteelErr}, - steel_vm::vm::{threads::closure_into_serializable, BuiltInSignature, Continuation}, + steel_vm::{ + primitives::realp, + vm::{threads::closure_into_serializable, BuiltInSignature, Continuation}, + }, values::port::SteelPort, values::{ closed::{Heap, HeapRef, MarkAndSweepContext}, @@ -22,7 +25,7 @@ use crate::{ }, values::{functions::BoxedDynFunction, structs::UserDefinedStruct}, }; - +use std::vec::IntoIter; use std::{ any::{Any, TypeId}, cell::{Ref, RefCell, RefMut}, @@ -40,8 +43,6 @@ use std::{ task::Context, }; -use std::vec::IntoIter; - // TODO #[macro_export] macro_rules! list { @@ -71,7 +72,7 @@ use futures_util::future::Shared; use futures_util::FutureExt; use crate::values::lists::List; -use num::{BigInt, BigRational, Rational32, ToPrimitive}; +use num::{BigInt, BigRational, Rational32, Signed, ToPrimitive, Zero}; use steel_parser::tokens::MaybeBigInt; use self::cycles::{CycleDetector, IterativeDropHandler}; @@ -1216,6 +1217,53 @@ pub enum SteelVal { BigNum(Gc), // Like Rational but supports larger numerators and denominators. BigRational(Gc), + // A complex number. + Complex(Gc), +} + +/// Contains a complex number. +/// +/// TODO: Optimize the contents of complex value. Holding `SteelVal` makes it easier to use existing +/// operations but a more specialized representation may be faster. +#[derive(Clone, Hash, PartialEq)] +pub struct SteelComplex { + /// The real part of the complex number. + pub re: SteelVal, + /// The imaginary part of the complex number. + pub im: SteelVal, +} + +impl SteelComplex { + pub fn new(real: SteelVal, imaginary: SteelVal) -> SteelComplex { + SteelComplex { + re: real, + im: imaginary, + } + } +} + +impl IntoSteelVal for SteelComplex { + fn into_steelval(self) -> Result { + Ok(match self.im { + NumV(n) if n.is_zero() => self.re, + IntV(0) => self.re, + _ => SteelVal::Complex(Gc::new(self)), + }) + } +} + +impl SteelComplex { + /// Returns `true` if the imaginary part is negative. + fn imaginary_is_negative(&self) -> bool { + match &self.im { + NumV(x) => x.is_negative(), + IntV(x) => x.is_negative(), + Rational(x) => x.is_negative(), + BigNum(x) => x.is_negative(), + SteelVal::BigRational(x) => x.is_negative(), + _ => unreachable!(), + } + } } impl SteelVal { @@ -1592,11 +1640,12 @@ impl Hash for SteelVal { NumV(n) => n.to_string().hash(state), IntV(i) => i.hash(state), Rational(f) => f.hash(state), + BigNum(n) => n.hash(state), + BigRational(f) => f.hash(state), + Complex(x) => x.hash(state), CharV(c) => c.hash(state), ListV(l) => l.hash(state), CustomStruct(s) => s.hash(state), - BigNum(n) => n.hash(state), - BigRational(f) => f.hash(state), // Pair(cell) => { // cell.hash(state); // } @@ -1959,10 +2008,14 @@ pub fn number_equality(left: &SteelVal, right: &SteelVal) -> Result { | (BigRational(_), BigNum(_)) | (BigNum(_), BigRational(_)) => false, (IntV(_), BigNum(_)) | (BigNum(_), IntV(_)) => false, + (Complex(x), Complex(y)) => { + number_equality(&x.re, &y.re)? == BoolV(true) + && number_equality(&x.im, &y.re)? == BoolV(true) + } + (Complex(_), _) | (_, Complex(_)) => false, _ => stop!(TypeMismatch => "= expects two numbers, found: {:?} and {:?}", left, right), }; - - Ok(SteelVal::BoolV(result)) + Ok(BoolV(result)) } fn partial_cmp_f64(l: &impl ToPrimitive, r: &impl ToPrimitive) -> Option { @@ -1972,8 +2025,8 @@ fn partial_cmp_f64(l: &impl ToPrimitive, r: &impl ToPrimitive) -> Option Option { - // TODO: Attempt to avoid converting to f64 for cases below as it may lead to precision - // loss at tiny and large values. + // TODO: Attempt to avoid converting to f64 for cases below as it may lead to precision loss + // at tiny and large values. match (self, other) { (IntV(l), IntV(r)) => l.partial_cmp(r), (IntV(l), NumV(r)) => partial_cmp_f64(l, r), @@ -2002,7 +2055,15 @@ impl PartialOrd for SteelVal { (BigRational(l), BigNum(r)) => partial_cmp_f64(l.as_ref(), r.as_ref()), (StringV(s), StringV(o)) => s.partial_cmp(o), (CharV(l), CharV(r)) => l.partial_cmp(r), - _ => None, // unimplemented for other types + (l, r) => { + // All real numbers (not complex) should have order defined. + debug_assert!( + !(realp(l) && realp(r)), + "Numbers {l:?} and {r:?} should implement partial_cmp" + ); + // Unimplemented for other types + None + } } } } diff --git a/crates/steel-core/src/rvals/cycles.rs b/crates/steel-core/src/rvals/cycles.rs index 6717a1e92..ccad14428 100644 --- a/crates/steel-core/src/rvals/cycles.rs +++ b/crates/steel-core/src/rvals/cycles.rs @@ -160,6 +160,10 @@ impl CycleDetector { IntV(x) => write!(f, "{x}"), Rational(x) => write!(f, "{n}/{d}", n = x.numer(), d = x.denom()), BigRational(x) => write!(f, "{n}/{d}", n = x.numer(), d = x.denom()), + Complex(x) if x.imaginary_is_negative() => { + write!(f, "{re}{im}i", re = x.re, im = x.im) + } + Complex(x) => write!(f, "{re}+{im}i", re = x.re, im = x.im), StringV(s) => write!(f, "{s:?}"), BigNum(b) => write!(f, "{}", b.as_ref()), CharV(c) => { @@ -284,6 +288,7 @@ impl CycleDetector { IntV(x) => write!(f, "{x}"), Rational(x) => write!(f, "{n}/{d}", n = x.numer(), d = x.denom()), BigRational(x) => write!(f, "{n}/{d}", n = x.numer(), d = x.denom()), + Complex(x) => write!(f, "{re}+{im}i", re = x.re, im = x.im), StringV(s) => write!(f, "{s:?}"), CharV(c) => { if c.is_ascii_control() { @@ -1109,6 +1114,7 @@ impl<'a> BreadthFirstSearchSteelValVisitor for IterativeDropHandler<'a> { Rational(x) => self.visit_rational(x), BigRational(x) => self.visit_bigrational(x), BigNum(b) => self.visit_bignum(b), + Complex(_) => unimplemented!(), CharV(c) => self.visit_char(c), VectorV(v) => self.visit_immutable_vector(v), Void => self.visit_void(), @@ -1174,6 +1180,7 @@ pub trait BreadthFirstSearchSteelValVisitor { Rational(x) => self.visit_rational(x), BigRational(x) => self.visit_bigrational(x), BigNum(b) => self.visit_bignum(b), + Complex(_) => unimplemented!(), CharV(c) => self.visit_char(c), VectorV(v) => self.visit_immutable_vector(v), Void => self.visit_void(), @@ -1265,6 +1272,7 @@ pub trait BreadthFirstSearchSteelValReferenceVisitor<'a> { IntV(i) => self.visit_int(*i), Rational(x) => self.visit_rational(*x), BigRational(x) => self.visit_bigrational(x), + Complex(_) => unimplemented!(), CharV(c) => self.visit_char(*c), VectorV(v) => self.visit_immutable_vector(v), Void => self.visit_void(), @@ -1909,6 +1917,7 @@ impl PartialEq for SteelVal { (Rational(l), Rational(r)) => l == r, (BigRational(l), BigRational(r)) => l == r, (BigNum(l), BigNum(r)) => l == r, + (Complex(l), Complex(r)) => l == r, (StringV(l), StringV(r)) => l == r, (SymbolV(l), SymbolV(r)) => l == r, (CharV(l), CharV(r)) => l == r, @@ -1923,76 +1932,78 @@ impl PartialEq for SteelVal { // (CustomStruct(l), CustomStruct(r)) => l == r, // (Custom(l), Custom(r)) => Gc::ptr_eq(l, r), // (HeapAllocated(l), HeapAllocated(r)) => l.get() == r.get(), - (left, right) => LEFT_QUEUE.with(|left_queue| { - RIGHT_QUEUE.with(|right_queue| { - VISITED_SET.with(|visited_set| { - match ( - left_queue.try_borrow_mut(), - right_queue.try_borrow_mut(), - visited_set.try_borrow_mut(), - ) { - (Ok(mut left_queue), Ok(mut right_queue), Ok(mut visited_set)) => { - let mut equality_handler = RecursiveEqualityHandler { - left: EqualityVisitor { - queue: &mut left_queue, - }, - right: EqualityVisitor { - queue: &mut right_queue, - }, - visited: &mut visited_set, - // found_mutable_object: false, - }; - - let res = - equality_handler.compare_equality(left.clone(), right.clone()); - - // EQ_DEPTH.with(|x| x.set(0)); - - reset_eq_depth(); - - // Clean up! - equality_handler.left.queue.clear(); - equality_handler.right.queue.clear(); - equality_handler.visited.clear(); - - res - } - _ => { - let mut left_queue = Vec::new(); - let mut right_queue = Vec::new(); + (left, right) => { + LEFT_QUEUE.with(|left_queue| { + RIGHT_QUEUE.with(|right_queue| { + VISITED_SET.with(|visited_set| { + match ( + left_queue.try_borrow_mut(), + right_queue.try_borrow_mut(), + visited_set.try_borrow_mut(), + ) { + (Ok(mut left_queue), Ok(mut right_queue), Ok(mut visited_set)) => { + let mut equality_handler = RecursiveEqualityHandler { + left: EqualityVisitor { + queue: &mut left_queue, + }, + right: EqualityVisitor { + queue: &mut right_queue, + }, + visited: &mut visited_set, + // found_mutable_object: false, + }; + + let res = equality_handler + .compare_equality(left.clone(), right.clone()); + + // EQ_DEPTH.with(|x| x.set(0)); + + reset_eq_depth(); + + // Clean up! + equality_handler.left.queue.clear(); + equality_handler.right.queue.clear(); + equality_handler.visited.clear(); + + res + } + _ => { + let mut left_queue = Vec::new(); + let mut right_queue = Vec::new(); - let mut visited_set = fxhash::FxHashSet::default(); + let mut visited_set = fxhash::FxHashSet::default(); - // EQ_DEPTH.with(|x| x.set(x.get() + 1)); + // EQ_DEPTH.with(|x| x.set(x.get() + 1)); - increment_eq_depth(); + increment_eq_depth(); - // println!("{}", EQ_DEPTH.with(|x| x.get())); + // println!("{}", EQ_DEPTH.with(|x| x.get())); - let mut equality_handler = RecursiveEqualityHandler { - left: EqualityVisitor { - queue: &mut left_queue, - }, - right: EqualityVisitor { - queue: &mut right_queue, - }, - visited: &mut visited_set, - // found_mutable_object: false, - }; + let mut equality_handler = RecursiveEqualityHandler { + left: EqualityVisitor { + queue: &mut left_queue, + }, + right: EqualityVisitor { + queue: &mut right_queue, + }, + visited: &mut visited_set, + // found_mutable_object: false, + }; - let res = - equality_handler.compare_equality(left.clone(), right.clone()); + let res = equality_handler + .compare_equality(left.clone(), right.clone()); - // EQ_DEPTH.with(|x| x.set(x.get() - 1)); + // EQ_DEPTH.with(|x| x.set(x.get() - 1)); - decrement_eq_depth(); + decrement_eq_depth(); - res + res + } } - } + }) }) }) - }), + } } } } diff --git a/crates/steel-core/src/steel_vm/primitives.rs b/crates/steel-core/src/steel_vm/primitives.rs index 82e3db2d3..fd067517b 100644 --- a/crates/steel-core/src/steel_vm/primitives.rs +++ b/crates/steel-core/src/steel_vm/primitives.rs @@ -34,8 +34,8 @@ use crate::{ rvals::{ as_underlying_type, cycles::{BreadthFirstSearchSteelValVisitor, SteelCycleCollector}, - FromSteelVal, FunctionSignature, MutFunctionSignature, SteelString, ITERATOR_FINISHED, - NUMBER_EQUALITY_DEFINITION, + FromSteelVal, FunctionSignature, MutFunctionSignature, SteelComplex, SteelString, + ITERATOR_FINISHED, NUMBER_EQUALITY_DEFINITION, }, steel_vm::{ builtin::{get_function_metadata, get_function_name, Arity}, @@ -683,9 +683,15 @@ pub fn numberp(value: &SteelVal) -> bool { | SteelVal::Rational(_) | SteelVal::BigRational(_) | SteelVal::NumV(_) + | SteelVal::Complex(_) ) } +#[steel_derive::function(name = "complex?", constant = true)] +pub fn complexp(value: &SteelVal) -> bool { + numberp(value) +} + #[steel_derive::function(name = "int?", constant = true)] fn intp(value: &SteelVal) -> bool { matches!(value, SteelVal::IntV(_) | SteelVal::BigNum(_)) @@ -702,7 +708,7 @@ fn floatp(value: &SteelVal) -> bool { } #[steel_derive::function(name = "real?", constant = true)] -fn realp(value: &SteelVal) -> bool { +pub fn realp(value: &SteelVal) -> bool { matches!( value, SteelVal::IntV(_) @@ -838,6 +844,7 @@ fn identity_module() -> BuiltInModule { // .register_value("int?", gen_pred!(IntV)) .register_native_fn_definition(NOT_DEFINITION) .register_native_fn_definition(NUMBERP_DEFINITION) + .register_native_fn_definition(COMPLEXP_DEFINITION) .register_native_fn_definition(INTP_DEFINITION) .register_native_fn_definition(INTEGERP_DEFINITION) .register_native_fn_definition(FLOATP_DEFINITION) @@ -901,6 +908,9 @@ fn exact_to_inexact(number: &SteelVal) -> Result { SteelVal::BigRational(f) => f.to_f64().unwrap().into_steelval(), SteelVal::NumV(n) => n.into_steelval(), SteelVal::BigNum(n) => Ok(SteelVal::NumV(n.to_f64().unwrap())), + SteelVal::Complex(x) => { + SteelComplex::new(exact_to_inexact(&x.re)?, exact_to_inexact(&x.im)?).into_steelval() + } _ => stop!(TypeMismatch => "exact->inexact expects a number type, found: {}", number), } } @@ -917,7 +927,7 @@ fn round(number: &SteelVal) -> Result { SteelVal::Rational(f) => f.round().into_steelval(), SteelVal::BigRational(f) => f.round().into_steelval(), SteelVal::BigNum(n) => Ok(SteelVal::BigNum(n.clone())), - _ => stop!(TypeMismatch => "round expects a number type, found: {}", number), + _ => stop!(TypeMismatch => "round expects a real number, found: {}", number), } } @@ -930,7 +940,7 @@ fn abs(number: &SteelVal) -> Result { SteelVal::Rational(f) => f.abs().into_steelval(), SteelVal::BigRational(f) => f.abs().into_steelval(), SteelVal::BigNum(n) => n.abs().into_steelval(), - _ => stop!(TypeMismatch => "abs expects a number type, found: {}", number), + _ => stop!(TypeMismatch => "abs expects a real number, found: {}", number), } } @@ -992,8 +1002,8 @@ fn expt(left: &SteelVal, right: &SteelVal) -> Result { .unwrap() .powf(r.to_f64().unwrap()) .into_steelval(), - _ => { - stop!(TypeMismatch => "expt expected two numbers") + (l, r) => { + stop!(TypeMismatch => "expt expected two numbers but found {} and {}", l, r) } } }