Skip to content

Commit

Permalink
Merge pull request #22 from vic1707/HOF
Browse files Browse the repository at this point in the history
Hof
  • Loading branch information
vic1707 authored Nov 29, 2023
2 parents 608670c + e09b5ed commit e321698
Show file tree
Hide file tree
Showing 14 changed files with 248 additions and 40 deletions.
32 changes: 29 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,27 @@ fn double(x: f64) -> f64 {
x * 2.0
}

const DOUBLE: Function = Function::new("double", move |args| double(args[0]), Some(1));
const DOUBLE: Function = Function::new_static("double", move |args| double(args[0]), Some(1));
// or with the macro (will do an automatic wrapping)
const DOUBLE_MACRO: Function = xprs_fn!("double", double, 1);

fn variadic_sum(args: &[f64]) -> f64 {
args.iter().sum()
}

const SUM: Function = Function::new("sum", variadic_sum, None);
const SUM: Function = Function::new_static("sum", variadic_sum, None);
// or with the macro (no wrapping is done for variadic functions)
const SUM_MACRO: Function = xprs_fn!("sum", variadic_sum);

// if a functions captures a variable (cannot be coerced to a static function)
const X: f64 = 42.0;
fn show_capture() {
let captures = |arg: f64| { X + arg };

let CAPTURES: Function = Function::new_dyn("captures", move |args| captures(args[0]), Some(1));
// or with the macro (will do an automatic wrapping)
let CAPTURES_MACRO: Function = xprs_fn!("captures", dyn captures, 1);
}
```

To use a [`Context`] and a [`Parser`] you can do the following:
Expand Down Expand Up @@ -224,7 +234,23 @@ Note2: `sum` and `mean` can take any number of arguments (if none, returns `0` a

## Higher order functions

TODO: Coming soon
You can define functions in a context based on a previously parsed expression.

```rust
use xprs::{xprs_fn, Context, Parser, Xprs};

fn main() {
let xprs_hof = Xprs::try_from("2x + y").unwrap();
let fn_hof = xprs_hof.bind2("x", "y").unwrap();
let hof = xprs_fn!("hof", dyn fn_hof, 2);
let ctx = Context::default().with_fn(hof);
let parser = Parser::new_with_ctx(ctx);

let xprs = parser.parse("hof(2, 3)").unwrap();

println!("hof(2, 3) = {}", xprs.eval_no_vars().unwrap());
}
```

These examples and others can be found in the [examples](./examples) directory.

Expand Down
5 changes: 3 additions & 2 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::{HashMap, HashSet};
use crate::token::Function;

/// Represents a symbol in the context.
#[derive(Debug, PartialEq, PartialOrd, Clone, Copy)]
#[derive(Debug, PartialEq, PartialOrd, Clone)]
#[non_exhaustive]
pub enum Symbol {
/// A variable.
Expand Down Expand Up @@ -38,7 +38,8 @@ impl From<Function> for Symbol {
/// let mut context = Context::default()
/// .with_expected_vars(["y"].into())
/// .with_var("x", 42.0)
/// .with_fn(sin_xprs_func);
/// // clone because assert_eq! is used later
/// .with_fn(sin_xprs_func.clone());
///
/// let x_var = context.get("x");
/// assert_eq!(x_var, Some(&Symbol::Variable(42.0)));
Expand Down
32 changes: 29 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,27 @@
//! x * 2.0
//! }
//!
//! const DOUBLE: Function = Function::new("double", move |args| double(args[0]), Some(1));
//! const DOUBLE: Function = Function::new_static("double", move |args| double(args[0]), Some(1));
//! // or with the macro (will do an automatic wrapping)
//! const DOUBLE_MACRO: Function = xprs_fn!("double", double, 1);
//!
//! fn variadic_sum(args: &[f64]) -> f64 {
//! args.iter().sum()
//! }
//!
//! const SUM: Function = Function::new("sum", variadic_sum, None);
//! const SUM: Function = Function::new_static("sum", variadic_sum, None);
//! // or with the macro (no wrapping is done for variadic functions)
//! const SUM_MACRO: Function = xprs_fn!("sum", variadic_sum);
//!
//! // if a functions captures a variable (cannot be coerced to a static function)
//! const X: f64 = 42.0;
//! fn show_capture() {
//! let captures = |arg: f64| { X + arg };
//!
//! let CAPTURES: Function = Function::new_dyn("captures", move |args| captures(args[0]), Some(1));
//! // or with the macro (will do an automatic wrapping)
//! let CAPTURES_MACRO: Function = xprs_fn!("captures", dyn captures, 1);
//! }
//! ```
//!
//! To use a [`Context`] and a [`Parser`] you can do the following:
Expand Down Expand Up @@ -224,7 +234,23 @@
//!
//! ## Higher order functions
//!
//! TODO: Coming soon
//! You can define functions in a context based on a previously parsed expression.
//!
//! ```rust
//! use xprs::{xprs_fn, Context, Parser, Xprs};
//!
//! fn main() {
//! let xprs_hof = Xprs::try_from("2x + y").unwrap();
//! let fn_hof = xprs_hof.bind2("x", "y").unwrap();
//! let hof = xprs_fn!("hof", dyn fn_hof, 2);
//! let ctx = Context::default().with_fn(hof);
//! let parser = Parser::new_with_ctx(ctx);
//!
//! let xprs = parser.parse("hof(2, 3)").unwrap();
//!
//! println!("hof(2, 3) = {}", xprs.eval_no_vars().unwrap());
//! }
//! ```
//!
//! These examples and others can be found in the [examples](./examples) directory.
//!
Expand Down
2 changes: 1 addition & 1 deletion src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl<'input, 'ctx> ParserImpl<'input, 'ctx> {
let ident = self
.ctx
.get(name)
.copied()
.cloned()
.map_or_else(|| Identifier::from_str(name), Into::into);

let el = match ident {
Expand Down
8 changes: 4 additions & 4 deletions src/tests/issues/issue_15.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
#![allow(clippy::min_ident_chars)]
use core::f64::consts::E;
/* Crate imports */
use super::super::macros::assert_f64_eq;
use crate::Parser;

const ERROR_MARGIN: f64 = f64::EPSILON;

const VALID_TEST_CASES: [(&str, f64); 9] = [
("1", 1.0),
("1.2", 1.2),
Expand All @@ -31,8 +30,9 @@ fn parse_number() {
let result = parser.parse(input).unwrap().eval(&[].into());
assert!(result.is_ok(), "Should have parsed: '{input}'.");
let value = result.unwrap();
assert!(
(value - expected).abs() < ERROR_MARGIN,
assert_f64_eq!(
value,
expected,
"{input}\nExpected: {expected}, got: {value}"
);
}
Expand Down
10 changes: 10 additions & 0 deletions src/tests/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
macro_rules! assert_f64_eq {
($left:expr, $right:expr) => {
assert!(($left - $right).abs() < f64::EPSILON);
};
($left:expr, $right:expr, $($arg:tt)+) => {
assert!(($left - $right).abs() < f64::EPSILON, $($arg)+);
};
}

pub(crate) use assert_f64_eq;
1 change: 1 addition & 0 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/* Modules */
mod issues;
mod macros;
mod parser;
mod thread_safety;
mod xprs;
3 changes: 2 additions & 1 deletion src/tests/thread_safety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
context::{Context, Symbol},
element::{BinOp, Element, FunctionCall, UnOp},
parser::{ErrorKind, ParseError, Parser},
token::{Function, Identifier, Operator},
token::{FnPointer, Function, Identifier, Operator},
xprs::{BindError, EvalError, Xprs},
};

Expand All @@ -24,6 +24,7 @@ const fn test_thread_safety() {
is_sized_send_sync_unpin::<ParseError>();
is_sized_send_sync_unpin::<Parser>();
// token module
is_sized_send_sync_unpin::<FnPointer>();
is_sized_send_sync_unpin::<Function>();
is_sized_send_sync_unpin::<Identifier>();
is_sized_send_sync_unpin::<Operator>();
Expand Down
8 changes: 4 additions & 4 deletions src/tests/xprs/eval.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/* Built-in imports */
use std::collections::HashMap;
/* Crate imports */
use super::super::macros::assert_f64_eq;
use crate::Parser;

const ERROR_MARGIN: f64 = f64::EPSILON;

// shitty type because of clippy and default numeric fallback
// https://github.com/rust-lang/rust-clippy/issues/11535
type InputVarsResult = (&'static str, &'static [(&'static str, f64)], f64);
Expand Down Expand Up @@ -46,8 +45,9 @@ fn test_valid_eval() {
let var_map: HashMap<&str, f64> = vars.iter().copied().collect();
let xprs = parser.parse(input).unwrap();
let result = xprs.eval(&var_map).unwrap();
assert!(
(result - expected).abs() < ERROR_MARGIN,
assert_f64_eq!(
result,
expected,
"{input}\nExpected: {expected}, got: {result}"
);
}
Expand Down
30 changes: 30 additions & 0 deletions src/tests/xprs/hof.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Crate imports */
use super::super::macros::assert_f64_eq;
use crate::{xprs_fn, Context, Parser, Xprs};

#[test]
fn test_higher_order_functions() {
let xprs_hof = Xprs::try_from("2x + y").unwrap();
let fn_hof = xprs_hof.bind2("x", "y").unwrap();
let hof = xprs_fn!("hof", dyn fn_hof, 2);
let ctx = Context::default().with_fn(hof);
let parser = Parser::new_with_ctx(ctx);

let bare_use = parser.parse("hof(2, 3)").unwrap();
assert_f64_eq!(bare_use.eval_unchecked(&[].into()), 7.0_f64);

let invalid_use = parser.parse("hof(2)");
assert!(
invalid_use.is_err(),
"Expected error for invalid use of hof"
);

let complex_use = parser.parse("hof(2, 3) + 3 * hof(4, 5)").unwrap();
assert_f64_eq!(complex_use.eval_unchecked(&[].into()), 46.0_f64);

let nested_var_use = parser.parse("hof(x, hof(2, 3))").unwrap();
assert_f64_eq!(
nested_var_use.eval_unchecked(&[("x", 42.0_f64)].into()),
91.0_f64
);
}
1 change: 1 addition & 0 deletions src/tests/xprs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* Modules */
mod eval;
mod hof;
mod simplify;
Loading

0 comments on commit e321698

Please sign in to comment.