diff --git a/src/expression/deep.rs b/src/expression/deep.rs index e39e97e..9461b01 100644 --- a/src/expression/deep.rs +++ b/src/expression/deep.rs @@ -278,7 +278,14 @@ mod detail { while idx_tkn < parsed_tokens.len() { match &parsed_tokens[idx_tkn] { ParsedToken::Op(op) => { - if idx_tkn > 0 && parser::is_operator_binary(op, &parsed_tokens[idx_tkn - 1])? { + if parser::is_operator_binary( + op, + if idx_tkn == 0 { + None + } else { + Some(&parsed_tokens[idx_tkn - 1]) + }, + )? { bin_ops.push(op.bin()?); reprs_bin_ops.push(op.repr()); idx_tkn += 1; diff --git a/src/expression/flat.rs b/src/expression/flat.rs index 9abe03b..af50b80 100644 --- a/src/expression/flat.rs +++ b/src/expression/flat.rs @@ -381,7 +381,14 @@ mod detail { where T: DataType, { - Ok(idx > 0 && parser::is_operator_binary(op, &parsed_tokens[idx - 1])?) + parser::is_operator_binary( + op, + if idx > 0 { + Some(&parsed_tokens[idx - 1]) + } else { + None + }, + ) } type ExResultOption = ExResult>; @@ -467,7 +474,7 @@ mod detail { match p { Paren::Close => { let err_msg = - "a unary operator cannot on the left of a closing paren"; + "a unary operator cannot on the left of a closing paren or comma"; return Err(ExError::new(err_msg)); } Paren::Open => unary_stack.push((idx_tkn, depth)), @@ -532,6 +539,15 @@ mod detail { } } } + let n_ops = flat_ops.len(); + let n_nodes = flat_nodes.len(); + if n_ops + 1 != n_nodes { + Err(exerr!( + "we have {} ops and {} node. we always need one more node than op.", + n_ops, + n_nodes + ))? + } let indices = prioritized_indices_flat(&flat_ops, &flat_nodes); Ok(FlatEx { nodes: flat_nodes, diff --git a/src/operators.rs b/src/operators.rs index 36fe36f..8847450 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -351,6 +351,14 @@ impl MakeOperators for FloatOpsFactory { }, |a| -a, ), + Operator::make_bin( + "atan2", + BinOp { + apply: |y, x| y.atan2(x), + prio: 0, + is_commutative: false, + }, + ), Operator::make_unary("abs", |a| a.abs()), Operator::make_unary("signum", |a| a.signum()), Operator::make_unary("sin", |a| a.sin()), diff --git a/src/parser.rs b/src/parser.rs index 8b28b8c..fa089ab 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use regex::Regex; use smallvec::SmallVec; use std::fmt::Debug; +use std::mem; #[derive(Debug, PartialEq, Eq)] pub enum Paren { @@ -13,13 +14,27 @@ pub enum Paren { Close, } -#[derive(Debug, PartialEq, Eq)] +#[derive(PartialEq, Eq)] pub enum ParsedToken<'a, T: DataType> { Num(T), Paren(Paren), Op(Operator<'a, T>), Var(&'a str), } +impl<'a, T> Debug for ParsedToken<'a, T> +where + T: DataType, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Num(x) => f.write_str(format!("{x:?}").as_str()), + Self::Op(op) => f.write_str(op.repr()), + Self::Paren(Paren::Open) => f.write_str("("), + Self::Paren(Paren::Close) => f.write_str(")"), + Self::Var(v) => f.write_str(v), + } + } +} /// Returns the index of the variable in the slice. Panics if not available! pub fn find_var_index(name: &str, parsed_vars: &[&str]) -> usize { @@ -35,12 +50,12 @@ pub fn find_var_index(name: &str, parsed_vars: &[&str]) -> usize { /// Disambiguates operators based on predecessor token. pub fn is_operator_binary<'a, T: DataType>( op: &Operator<'a, T>, - parsed_token_on_the_left: &ParsedToken<'a, T>, + parsed_token_on_the_left: Option<&ParsedToken<'a, T>>, ) -> ExResult { if op.has_bin() && !op.has_unary() { match parsed_token_on_the_left { - ParsedToken::Op(op_) => Err(exerr!( - "a binary operator cannot be on the right another operator, {:?} next to {:?}", + Some(ParsedToken::Op(op_)) => Err(exerr!( + "a binary operator cannot be on the right of another operator, {:?} next to {:?}", op, op_ )), @@ -48,9 +63,10 @@ pub fn is_operator_binary<'a, T: DataType>( } } else if op.has_bin() && op.has_unary() { Ok(match parsed_token_on_the_left { - ParsedToken::Num(_) | ParsedToken::Var(_) => true, - ParsedToken::Paren(p) => *p == Paren::Close, - ParsedToken::Op(_) => false, + Some(ParsedToken::Num(_)) | Some(ParsedToken::Var(_)) => true, + Some(ParsedToken::Paren(p)) => *p == Paren::Close, + Some(ParsedToken::Op(_)) => false, + None => false, }) } else { Ok(false) @@ -99,6 +115,33 @@ fn next_char_boundary(text: &str, start_idx: usize) -> usize { .expect("there has to be a char boundary somewhere") } +fn find_op_of_comma(parsed_tokens: &[ParsedToken]) -> Option +where + T: DataType, +{ + let paren_counter = parsed_tokens.iter().rev().scan(0, |state, pt| { + *state += match pt { + ParsedToken::Paren(Paren::Close) => -1, + ParsedToken::Paren(Paren::Open) => 1, + _ => 0, + }; + Some(*state) + }); + + let rev_idx = parsed_tokens + .iter() + .rev() + .zip(paren_counter) + .enumerate() + .find(|(_, (pt, paren_cnt))| { + matches!(pt, + ParsedToken::Op(_) if *paren_cnt == 1) + }) + .map(|(i, _)| i); + + rev_idx.map(|ridx| parsed_tokens.len() - 1 - ridx) +} + /// Parses tokens of a text with regexes and returns them as a vector /// /// # Arguments @@ -152,18 +195,41 @@ where }; let mut res: SmallVec<[_; N_NODES_ON_STACK]> = SmallVec::new(); let mut cur_byte_offset = 0usize; + let mut close_additional_paren = false; + let mut open_paren_count = 0; for (i, c) in text.char_indices() { if c == ' ' && i == cur_byte_offset { cur_byte_offset += 1; } else if i == cur_byte_offset && cur_byte_offset < text.len() { let text_rest = &text[cur_byte_offset..]; let cur_byte_offset_tmp = cur_byte_offset; - let next_parsed_token = if c == '(' { + if c == '(' { cur_byte_offset += 1; - ParsedToken::::Paren(Paren::Open) + res.push(ParsedToken::::Paren(Paren::Open)); + open_paren_count += 1; } else if c == ')' { cur_byte_offset += 1; - ParsedToken::::Paren(Paren::Close) + open_paren_count -= 1; + res.push(ParsedToken::::Paren(Paren::Close)); + if close_additional_paren && open_paren_count == 0 { + res.push(ParsedToken::Paren(Paren::Close)); + close_additional_paren = false; + } + } else if c == ',' { + // this is for binary operators with function call syntax. + // we simply replace op(a,b) by ((a)op(b)) where the outer parens + // are added to increase the priority as expected from the function + // call syntax + cur_byte_offset += 1; + let op_idx = find_op_of_comma(&res).ok_or_else(|| { + exerr!("need operator for comma, could be missing operator or paren mismatch",) + })?; + let op_at_comma = mem::replace(&mut res[op_idx], ParsedToken::Paren(Paren::Open)); + close_additional_paren = true; + open_paren_count = 1; + res.push(ParsedToken::Paren(Paren::Close)); + res.push(op_at_comma); + res.push(ParsedToken::Paren(Paren::Open)); } else if c == '{' { let n_count = text_rest .chars() @@ -172,31 +238,42 @@ where .sum(); let var_name = &text_rest[1..n_count]; cur_byte_offset += n_count + 1; - ParsedToken::::Var(var_name) + res.push(ParsedToken::Var(var_name)); + if close_additional_paren && open_paren_count == 0 { + res.push(ParsedToken::Paren(Paren::Close)); + close_additional_paren = false; + } } else if let Some(num_str) = is_numeric(text_rest) { let n_bytes = num_str.len(); cur_byte_offset += n_bytes; - ParsedToken::::Num( + res.push(ParsedToken::::Num( num_str .parse::() .map_err(|e| exerr!("could not parse '{}', {:?}", num_str, e))?, - ) + )); + if close_additional_paren && open_paren_count == 0 { + res.push(ParsedToken::Paren(Paren::Close)); + close_additional_paren = false; + } } else if let Some(op) = find_ops(cur_byte_offset_tmp) { let n_bytes = op.repr().len(); cur_byte_offset += n_bytes; - match op.constant() { + res.push(match op.constant() { Some(constant) => ParsedToken::::Num(constant), None => ParsedToken::::Op((*op).clone()), - } + }); } else if let Some(var_str) = RE_VAR_NAME.find(text_rest) { let var_str = var_str.as_str(); let n_bytes = var_str.len(); cur_byte_offset += n_bytes; - ParsedToken::::Var(var_str) + res.push(ParsedToken::::Var(var_str)); + if close_additional_paren && open_paren_count == 0 { + res.push(ParsedToken::Paren(Paren::Close)); + close_additional_paren = false; + } } else { return Err(exerr!("don't know how to parse {}", text_rest)); - }; - res.push(next_parsed_token); + } } } Ok(res) @@ -210,23 +287,8 @@ fn make_err(msg: &str, left: &ParsedToken, right: &ParsedToken() -> [PairPreCondition<'a, T>; 9] { +fn make_pair_pre_conditions<'a, T: DataType>() -> [PairPreCondition<'a, T>; 7] { [ - PairPreCondition { - apply: |left, right| { - let num_var_str = - "a number/variable cannot be next to a number/variable, violated by "; - match (left, right) { - (ParsedToken::Num(_), ParsedToken::Var(_)) - | (ParsedToken::Var(_), ParsedToken::Num(_)) - | (ParsedToken::Num(_), ParsedToken::Num(_)) - | (ParsedToken::Var(_), ParsedToken::Var(_)) => { - make_err(num_var_str, left, right) - } - _ => Ok(()), - } - }, - }, PairPreCondition { apply: |left, right| match (left, right) { (ParsedToken::Paren(_p @ Paren::Close), ParsedToken::Num(_)) @@ -297,18 +359,6 @@ fn make_pair_pre_conditions<'a, T: DataType>() -> [PairPreCondition<'a, T>; 9] { } }, }, - PairPreCondition { - apply: |left, right| { - match (left, right) { - (ParsedToken::Paren(_p @ Paren::Open), ParsedToken::Op(op)) if !op.has_unary() => { - Err(exerr!( - "a binary operator cannot be on the right of an opening paren, violated by '{}'", - op.repr())) - } - _ => Ok(()), - } - }, - }, PairPreCondition { apply: |left, right| match (left, right) { ( @@ -451,12 +501,39 @@ fn test_preconditions() { test("12-(3-4)*2+ ((1/2)", "parentheses mismatch"); test(r"5\6", r"don't know how to parse \"); test(r"3.4.", r"don't know how to parse 3.4."); - test( - r"3. .4", - r"a number/variable cannot be next to a number/variable", - ); test( r"2sin({x})", r"number/variable cannot be on the left of a unary operator", ); } + +#[test] +fn test_find_comma_op() { + let pts = [ + ParsedToken::Paren(Paren::Close), + ParsedToken::Paren(Paren::Close), + ParsedToken::Op(Operator::make_bin( + "atan2", + crate::BinOp { + apply: |y: f64, x: f64| y.atan2(x), + prio: 1, + is_commutative: false, + }, + )), + ParsedToken::Paren(Paren::Open), + ]; + assert_eq!(Some(2), find_op_of_comma(&pts)); + let pts = [ + ParsedToken::Paren(Paren::Close), + ParsedToken::Op(Operator::make_bin( + "atan2", + crate::BinOp { + apply: |y: f64, x: f64| y.atan2(x), + prio: 1, + is_commutative: false, + }, + )), + ParsedToken::Paren(Paren::Open), + ]; + assert_eq!(Some(1), find_op_of_comma(&pts)); +} diff --git a/src/value.rs b/src/value.rs index 384f715..faf437b 100644 --- a/src/value.rs +++ b/src/value.rs @@ -122,6 +122,12 @@ where )), } } + pub fn to_float_val(self) -> Self { + match self.to_float() { + Ok(f) => Val::Float(f), + Err(e) => Val::Error(e), + } + } } impl From for Val @@ -340,7 +346,7 @@ where match (&a, &b) { (Val::Bool(a), Val::Bool(b)) => Val::Bool(*a || *b), _ => { - if a.clone() >= b.clone() { + if a >= b { a } else { b @@ -348,6 +354,22 @@ where } } } +fn atan2(a: Val, b: Val) -> Val +where + I: DataType + PrimInt + Signed, + ::Err: Debug, + F: DataType + Float, + ::Err: Debug, +{ + let a = a.to_float_val(); + let b = b.to_float_val(); + match (a, b) { + (Val::Float(a), Val::Float(b)) => Val::Float(a.atan2(b)), + (_, Val::Error(e)) => Val::Error(e), + (Val::Error(e), _) => Val::Error(e), + _ => Val::Error(exerr!("could not apply atan2 to",)), + } +} macro_rules! unary_match_name { ($name:ident, $scalar:ident, $(($unused_ops:expr, $variants:ident)),+) => { match $scalar { @@ -561,6 +583,14 @@ where is_commutative: false, }, ), + Operator::make_bin( + "atan2", + BinOp { + apply: atan2, + prio: 0, + is_commutative: false, + }, + ), Operator::make_bin( "%", BinOp { @@ -694,9 +724,12 @@ where Operator::make_bin( "else", BinOp { - apply: |res_of_if, v| match res_of_if { - Val::None => v, - _ => res_of_if, + apply: |res_of_if, v| { + println!("debug {res_of_if:?} {v:?}"); + match res_of_if { + Val::None => v, + _ => res_of_if, + } }, prio: 0, is_commutative: false, diff --git a/tests/core.rs b/tests/core.rs index 2172f06..68a83a2 100644 --- a/tests/core.rs +++ b/tests/core.rs @@ -11,6 +11,7 @@ use regex::Regex; use std::fs::{self, File}; use std::io::{self, BufRead}; use std::iter::repeat; +use utils::assert_float_eq_f64; #[cfg(test)] use std::{ @@ -418,8 +419,7 @@ fn test_variables() -> ExResult<()> { let expr = FlatEx::::parse(sut)?; utils::assert_float_eq_f64( expr.eval(&[2.5, 3.7]).unwrap(), - -(2.5f64.sqrt()) / (2.5f64.tanh() * 2.0) - + 2.0 / ((3.7f64.sinh() * 4.0).sin()).asin(), + -(2.5f64.sqrt()) / (2.5f64.tanh() * 2.0) + 2.0 / ((3.7f64.sinh() * 4.0).sin()).asin(), ); let sut = "asin(sin(x)) + acos(cos(x)) + atan(tan(x))"; @@ -973,3 +973,37 @@ fn test_string_ops() { "xyabcMINUS".to_string() ); } + +#[test] +fn test_binary_function_style() { + use std::fmt::Debug; + fn test(s: &str, vars: &[f64], reference: f64) { + println!("testing {s}"); + fn test_<'a, EX: Express<'a, f64> + Debug>(s: &'a str, vars: &[f64], reference: f64) { + let expr = EX::parse(s).unwrap(); + assert_float_eq_f64(expr.eval(vars).unwrap(), reference); + } + println!("flatex..."); + test_::>(s, vars, reference); + println!("deepex..."); + test_::>(s, vars, reference); + } + test( + "atan2(0.2/y, x)", + &[1.2, 2.1], + (0.2 / 2.1_f64).atan2(1.2_f64), + ); + test("+ (1, -2) / 2", &[], -0.5); + test("/ 1 2 * 3", &[], 1.5); + test("atan2(1, 2) * 3", &[], 1.0f64.atan2(2.0) * 3.0); + test( + "2 + atan2(1, x / 2) * 3", + &[1.0], + 2.0 + 1.0f64.atan2(0.5) * 3.0, + ); + test( + "sin(atan2(1, x / 2)) * 3", + &[1.0], + (1.0f64.atan2(0.5)).sin() * 3.0, + ); +} diff --git a/tests/value.rs b/tests/value.rs index db15e0b..e19db78 100644 --- a/tests/value.rs +++ b/tests/value.rs @@ -150,52 +150,89 @@ fn test_to() -> ExResult<()> { Ok(()) } #[cfg(feature = "value")] +#[cfg(test)] +use exmex::{DeepEx, ValMatcher, ValOpsFactory}; +#[cfg(feature = "value")] +#[cfg(test)] +type Fx = FlatExVal; +#[cfg(feature = "value")] +#[cfg(test)] +type Dx<'a> = DeepEx<'a, Val, ValOpsFactory, ValMatcher>; +#[cfg(feature = "value")] #[test] fn test_no_vars() -> ExResult<()> { fn test_int(s: &str, reference: i32) -> ExResult<()> { - println!("=== testing\n{}", s); - let res = exmex::parse_val::(s)?.eval(&[])?.to_int(); - match res { - Ok(i) => { - assert_eq!(reference, i); - } - Err(e) => { - println!("{:?}", e); - unreachable!(); + fn test_<'a, EX>(s: &'a str, reference: i32) -> ExResult<()> + where + EX: Express<'a, Val>, + { + println!("=== testing\n{}", s); + let res = exmex::parse_val::(s)?.eval(&[])?.to_int(); + match res { + Ok(i) => { + assert_eq!(reference, i); + } + Err(e) => { + println!("{:?}", e); + unreachable!(); + } } + Ok(()) } - Ok(()) + test_::(s, reference)?; + test_::(s, reference) } fn test_float(s: &str, reference: f64) -> ExResult<()> { - println!("=== testing\n{}", s); - let expr = FlatExVal::::parse(s)?; - utils::assert_float_eq_f64(reference, expr.eval(&[])?.to_float()?); - Ok(()) + fn test_<'a, EX>(s: &'a str, reference: f64) -> ExResult<()> + where + EX: Express<'a, Val>, + { + println!("=== testing\n{}", s); + let expr = FlatExVal::::parse(s)?; + utils::assert_float_eq_f64(reference, expr.eval(&[])?.to_float()?); + Ok(()) + } + test_::(s, reference)?; + test_::(s, reference) } fn test_bool(s: &str, reference: bool) -> ExResult<()> { println!("=== testing\n{}", s); - let expr = FlatExVal::::parse(s)?; - assert_eq!(reference, expr.eval(&[])?.to_bool()?); - Ok(()) + fn test_<'a, EX>(s: &'a str, reference: bool) -> ExResult<()> + where + EX: Express<'a, Val>, + { + let expr = EX::parse(s)?; + assert_eq!(reference, expr.eval(&[])?.to_bool()?); + Ok(()) + } + test_::(s, reference)?; + test_::(s, reference) } fn test_error(s: &str) -> ExResult<()> { - let expr = FlatExVal::::parse(s); - match expr { - Ok(exp) => { - let v = exp.eval(&[])?; - match v { - Val::Error(e) => { - println!("found expected error {:?}", e); - Ok(()) + fn test_<'a, EX>(s: &'a str) -> ExResult<()> + where + EX: Express<'a, Val>, + { + let expr = EX::parse(s); + match expr { + Ok(exp) => { + let v = exp.eval(&[])?; + match v { + Val::Error(e) => { + println!("found expected error {:?}", e); + Ok(()) + } + _ => Err(exerr!("'{}' should fail but didn't", s)), } - _ => Err(exerr!("'{}' should fail but didn't", s)), } - } - Err(e) => { - println!("found expected error {:?}", e); - Ok(()) + Err(e) => { + println!("found expected error {:?}", e); + Ok(()) + } } } + test_::(&s)?; + test_::(&s) } fn test_none(s: &str) -> ExResult<()> { let expr = FlatExVal::::parse(s)?; @@ -204,6 +241,7 @@ fn test_no_vars() -> ExResult<()> { _ => Err(exerr!("'{}' should return none but didn't", s)), } } + test_error("if true else 2")?; test_int("1+2 if 1 > 0 else 2+4", 3)?; test_int("1+2 if 1 < 0 else 2+4", 6)?; test_error("929<<92")?; @@ -267,7 +305,6 @@ fn test_no_vars() -> ExResult<()> { test_bool("true == 1", false)?; test_bool("true else 2", true)?; test_int("1 else 2", 1)?; - test_error("if true else 2")?; test_none("2 if false")?; test_int("to_int(1)", 1)?; test_int("to_int(3.5)", 3)?; @@ -295,7 +332,7 @@ fn test_no_vars() -> ExResult<()> { "atanh(0.5)/asinh(-7.5)*acosh(2.3)", 0.5f64.atanh() / (-7.5f64).asinh() * 2.3f64.acosh(), )?; - + test_float("sin(atan2(1, 1.0 / 2.0))", (1.0f64.atan2(0.5)).sin())?; Ok(()) }