diff --git a/Cargo.toml b/Cargo.toml index ee50372..a95b606 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cel-eval" -version = "0.1.1" +version = "0.1.2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.htmlž @@ -14,7 +14,11 @@ cel-parser = "0.7.1" uniffi = { version = "0.28" } serde = { version = "1.0", features = ["serde_derive"] } serde_json = { version = "1.0" } +async-trait = "0.1.81" +smol = "2.0.1" +[dev-dependencies] +tokio = { version = "^1.20", features = ["rt-multi-thread", "macros"] } [build-dependencies] uniffi = { version = "0.28", features = [ "build" ] } diff --git a/build.rs b/build.rs index 94726f9..085db43 100644 --- a/build.rs +++ b/build.rs @@ -1,3 +1,3 @@ -fn main(){ +fn main() { uniffi::generate_scaffolding("./src/cel.udl").unwrap(); -} \ No newline at end of file +} diff --git a/src/ast.md b/src/ast.md new file mode 100644 index 0000000..e69c5ff --- /dev/null +++ b/src/ast.md @@ -0,0 +1,102 @@ +## AST example + +For convenience, here is a JSON example of an AST that can be evaluated by the library. + +```json +{ + "type": "And", + "value": [ + { + "type": "Relation", + "value": [ + { + "type": "Arithmetic", + "value": [ + { + "type": "Atom", + "value": { + "type": "Int", + "value": 5 + } + }, + { + "type": "Add" + }, + { + "type": "Atom", + "value": { + "type": "Int", + "value": 3 + } + } + ] + }, + { + "type": "GreaterThan" + }, + { + "type": "Atom", + "value": { + "type": "Int", + "value": 7 + } + } + ] + }, + { + "type": "Relation", + "value": [ + { + "type": "FunctionCall", + "value": [ + { + "type": "Member", + "value": [ + { + "type": "Ident", + "value": "name" + }, + { + "type": "Attribute", + "value": "length" + } + ] + }, + null, + [] + ] + }, + { + "type": "In" + }, + { + "type": "List", + "value": [ + { + "type": "Atom", + "value": { + "type": "Int", + "value": 5 + } + }, + { + "type": "Atom", + "value": { + "type": "Int", + "value": 10 + } + }, + { + "type": "Atom", + "value": { + "type": "Int", + "value": 15 + } + } + ] + } + ] + } + ] +} +``` \ No newline at end of file diff --git a/src/ast.rs b/src/ast.rs index ed87253..9cd84c7 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,9 +1,9 @@ -use std::collections::HashMap; -use std::sync::Arc; -use cel_parser::{ArithmeticOp, Atom, Expression, Member, RelationOp, UnaryOp}; +use crate::models::{PassableMap, PassableValue}; use cel_parser::Member::{Attribute, Fields, Index}; +use cel_parser::{ArithmeticOp, Atom, Expression, Member, RelationOp, UnaryOp}; use serde::{Deserialize, Serialize}; -use crate::models::{PassableMap, PassableValue}; +use std::collections::HashMap; +use std::sync::Arc; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub(crate) struct ASTExecutionContext { @@ -48,12 +48,20 @@ pub enum JSONUnaryOp { pub enum JSONExpression { Arithmetic(Box, JSONArithmeticOp, Box), Relation(Box, JSONRelationOp, Box), - Ternary(Box, Box, Box), + Ternary( + Box, + Box, + Box, + ), Or(Box, Box), And(Box, Box), Unary(JSONUnaryOp, Box), Member(Box, Box), - FunctionCall(Box, Option>, Vec), + FunctionCall( + Box, + Option>, + Vec, + ), List(Vec), Map(Vec<(JSONExpression, JSONExpression)>), Atom(JSONAtom), @@ -136,32 +144,31 @@ impl From for Expression { Box::new((*true_expr).into()), Box::new((*false_expr).into()), ), - JSONExpression::Or(left, right) => Expression::Or( - Box::new((*left).into()), - Box::new((*right).into()), - ), - JSONExpression::And(left, right) => Expression::And( - Box::new((*left).into()), - Box::new((*right).into()), - ), - JSONExpression::Unary(op, expr) => Expression::Unary( - op.into(), - Box::new((*expr).into()), - ), - JSONExpression::Member(expr, member) => Expression::Member( - Box::new((*expr).into()), - Box::new((*member).into()), - ), + JSONExpression::Or(left, right) => { + Expression::Or(Box::new((*left).into()), Box::new((*right).into())) + } + JSONExpression::And(left, right) => { + Expression::And(Box::new((*left).into()), Box::new((*right).into())) + } + JSONExpression::Unary(op, expr) => { + Expression::Unary(op.into(), Box::new((*expr).into())) + } + JSONExpression::Member(expr, member) => { + Expression::Member(Box::new((*expr).into()), Box::new((*member).into())) + } JSONExpression::FunctionCall(func, optional_expr, args) => Expression::FunctionCall( Box::new((*func).into()), optional_expr.map(|e| Box::new((*e).into())), args.into_iter().map(|e| e.into()).collect(), ), - JSONExpression::List(items) => Expression::List( - items.into_iter().map(|e| e.into()).collect(), - ), + JSONExpression::List(items) => { + Expression::List(items.into_iter().map(|e| e.into()).collect()) + } JSONExpression::Map(items) => Expression::Map( - items.into_iter().map(|(k, v)| (k.into(), v.into())).collect(), + items + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), ), JSONExpression::Atom(atom) => Expression::Atom(atom.into()), JSONExpression::Ident(s) => Expression::Ident(Arc::new(s)), @@ -175,7 +182,10 @@ impl From for Member { JSONMember::Attribute(s) => Attribute(Arc::new(s)), JSONMember::Index(expr) => Index(Box::new((*expr).into())), JSONMember::Fields(fields) => Fields( - fields.into_iter().map(|(k, v)| (Arc::new(k), v.into())).collect(), + fields + .into_iter() + .map(|(k, v)| (Arc::new(k), v.into())) + .collect(), ), } } @@ -213,32 +223,31 @@ impl From for JSONExpression { Box::new((*true_expr).into()), Box::new((*false_expr).into()), ), - Expression::Or(left, right) => JSONExpression::Or( - Box::new((*left).into()), - Box::new((*right).into()), - ), - Expression::And(left, right) => JSONExpression::And( - Box::new((*left).into()), - Box::new((*right).into()), - ), - Expression::Unary(op, expr) => JSONExpression::Unary( - op.into(), - Box::new((*expr).into()), - ), - Expression::Member(expr, member) => JSONExpression::Member( - Box::new((*expr).into()), - Box::new((*member).into()), - ), + Expression::Or(left, right) => { + JSONExpression::Or(Box::new((*left).into()), Box::new((*right).into())) + } + Expression::And(left, right) => { + JSONExpression::And(Box::new((*left).into()), Box::new((*right).into())) + } + Expression::Unary(op, expr) => { + JSONExpression::Unary(op.into(), Box::new((*expr).into())) + } + Expression::Member(expr, member) => { + JSONExpression::Member(Box::new((*expr).into()), Box::new((*member).into())) + } Expression::FunctionCall(func, optional_expr, args) => JSONExpression::FunctionCall( Box::new((*func).into()), optional_expr.map(|e| Box::new((*e).into())), args.into_iter().map(|e| e.into()).collect(), ), - Expression::List(items) => JSONExpression::List( - items.into_iter().map(|e| e.into()).collect(), - ), + Expression::List(items) => { + JSONExpression::List(items.into_iter().map(|e| e.into()).collect()) + } Expression::Map(items) => JSONExpression::Map( - items.into_iter().map(|(k, v)| (k.into(), v.into())).collect(), + items + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(), ), Expression::Atom(atom) => JSONExpression::Atom(atom.into()), Expression::Ident(s) => JSONExpression::Ident((*s).clone()), @@ -290,7 +299,10 @@ impl From for JSONMember { Attribute(s) => JSONMember::Attribute((*s).clone()), Index(expr) => JSONMember::Index(Box::new((*expr).into())), Fields(fields) => JSONMember::Fields( - fields.into_iter().map(|(k, v)| ((*k).clone(), v.into())).collect(), + fields + .into_iter() + .map(|(k, v)| ((*k).clone(), v.into())) + .collect(), ), } } @@ -310,62 +322,68 @@ impl From for JSONAtom { } } - #[cfg(test)] mod tests { use super::*; + use cel_interpreter::Program; + use cel_parser::parser::ExpressionParser; #[test] fn test_ast_serializing() { - // ((5 + 3) > 7) && (name.length() in [5, 10, 15]) - let expr = Expression::And( - Box::new(Expression::Relation( - Box::new(Expression::Arithmetic( - Box::new(Expression::Atom(Atom::Int(5))), - ArithmeticOp::Add, - Box::new(Expression::Atom(Atom::Int(3))) - )), - RelationOp::GreaterThan, - Box::new(Expression::Atom(Atom::Int(7))) + // ((5 + 3) > 7) && (name.length() in [5, 10, 15]) + let expr = Expression::And( + Box::new(Expression::Relation( + Box::new(Expression::Arithmetic( + Box::new(Expression::Atom(Atom::Int(5))), + ArithmeticOp::Add, + Box::new(Expression::Atom(Atom::Int(3))), )), - Box::new(Expression::Relation( - Box::new(Expression::FunctionCall( - Box::new(Expression::Member( - Box::new(Expression::Ident(Arc::new("name".to_string()))), - Box::new(Attribute(Arc::new("length".to_string()))) - )), - None, - vec![] + RelationOp::GreaterThan, + Box::new(Expression::Atom(Atom::Int(7))), + )), + Box::new(Expression::Relation( + Box::new(Expression::FunctionCall( + Box::new(Expression::Member( + Box::new(Expression::Ident(Arc::new("name".to_string()))), + Box::new(Attribute(Arc::new("length".to_string()))), )), - RelationOp::In, - Box::new(Expression::List(vec![ - Expression::Atom(Atom::Int(5)), - Expression::Atom(Atom::Int(10)), - Expression::Atom(Atom::Int(15)) - ])) - )) - ); + None, + vec![], + )), + RelationOp::In, + Box::new(Expression::List(vec![ + Expression::Atom(Atom::Int(5)), + Expression::Atom(Atom::Int(10)), + Expression::Atom(Atom::Int(15)), + ])), + )), + ); - // Convert to JSONExpression - let json_expr: JSONExpression = expr.clone().into(); + // Convert to JSONExpression + let json_expr: JSONExpression = expr.clone().into(); - // Serialize to JSON - let json_string = serde_json::to_string_pretty(&json_expr).unwrap(); + // Serialize to JSON + let json_string = serde_json::to_string_pretty(&json_expr).unwrap(); - println!("JSON representation:"); - println!("{}", json_string); + println!("JSON representation:"); + println!("{}", json_string); - // Deserialize back to JSONExpression - let deserialized_json_expr: JSONExpression = serde_json::from_str(&json_string).unwrap(); + let text = "platform.myMethod(\"test\") == platform.name && user.test == 1"; + let program = ExpressionParser::new().parse(text).unwrap(); + let program: JSONExpression = program.into(); + let serialized = serde_json::to_string_pretty(&program).unwrap(); + println!("-----------–\n\n\n{}--------------\n\n", serialized); + // Deserialize back to JSONExpression + let deserialized_json_expr: JSONExpression = serde_json::from_str(&json_string).unwrap(); - // Convert back to original Expression - let deserialized_expr: Expression = deserialized_json_expr.into(); + // Convert back to original Expression + let deserialized_expr: Expression = deserialized_json_expr.into(); - println!("\nDeserialized Expression:"); - println!("{:?}", deserialized_expr); + println!("\nDeserialized Expression:"); + println!("{:?}", deserialized_expr); - // Check if the original and deserialized expressions are equal - assert_eq!(expr, deserialized_expr); - println!("\nOriginal and deserialized expressions are equal!"); + // Check if the original and deserialized expressions are equal + assert_eq!(expr, deserialized_expr); + println!("\nOriginal and deserialized expressions are equal!"); } } diff --git a/src/cel.udl b/src/cel.udl index 234eca1..5111752 100644 --- a/src/cel.udl +++ b/src/cel.udl @@ -1,4 +1,6 @@ -callback interface HostContext { +[Trait] +interface HostContext { + [Async] string computed_property(string name, string args); }; diff --git a/src/lib.rs b/src/lib.rs index 3b72276..01a0f05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,19 @@ uniffi::include_scaffolding!("cel"); -mod models; mod ast; +mod models; -use std::collections::HashMap; +use crate::ast::{ASTExecutionContext, JSONExpression}; +use crate::models::PassableValue::Function; +use crate::models::{ExecutionContext, PassableMap, PassableValue}; +use crate::ExecutableType::{CompiledProgram, AST}; +use async_trait::async_trait; +use cel_interpreter::extractors::This; +use cel_interpreter::objects::{Key, Map, TryIntoValue}; use cel_interpreter::{Context, ExecutionError, Expression, FunctionContext, Program, Value}; +use std::collections::HashMap; use std::fmt; +use std::ops::Deref; use std::sync::{Arc, Mutex}; -use cel_interpreter::extractors::This; -use cel_interpreter::objects::{Key, Map, TryIntoValue}; -use crate::ast::{ASTExecutionContext, JSONExpression}; -use crate::ExecutableType::{AST, CompiledProgram}; -use crate::models::{ExecutionContext, PassableMap, PassableValue}; -use crate::models::PassableValue::{Function}; - /** * Host context trait that defines the methods that the host context should implement, @@ -20,8 +21,9 @@ use crate::models::PassableValue::{Function}; * CEL expression during evaluation, such as `platform.daysSinceEvent("event_name")` or similar. */ +#[async_trait] pub trait HostContext: Send + Sync { - fn computed_property(&self, name: String, args: String) -> String; + async fn computed_property(&self, name: String, args: String) -> String; } /** @@ -30,9 +32,15 @@ pub trait HostContext: Send + Sync { * @param host The host context to use for resolving properties * @return The result of the evaluation, either "true" or "false" */ -pub fn evaluate_ast_with_context(definition: String, host: Box) -> String { +pub fn evaluate_ast_with_context(definition: String, host: Arc) -> String { let data: ASTExecutionContext = serde_json::from_str(definition.as_str()).unwrap(); - execute_with(AST(data.expression.into()), data.variables, data.platform, host) + let host = host.clone(); + execute_with( + AST(data.expression.into()), + data.variables, + data.platform, + host, + ) } /** @@ -48,7 +56,6 @@ pub fn evaluate_ast(ast: String) -> String { res.to_string() } - /** * Evaluate a CEL expression with the given definition by compiling it first. * @param definition The definition of the expression, serialized as JSON. This defines the expression, the variables, and the platform properties. @@ -56,8 +63,8 @@ pub fn evaluate_ast(ast: String) -> String { * @return The result of the evaluation, either "true" or "false" */ -pub fn evaluate_with_context(definition: String, host: Box) -> String { - let data: Result = serde_json::from_str(definition.as_str()); +pub fn evaluate_with_context(definition: String, host: Arc) -> String { + let data: Result = serde_json::from_str(definition.as_str()); let data = match data { Ok(data) => data, Err(e) => { @@ -65,84 +72,134 @@ pub fn evaluate_with_context(definition: String, host: Box>>, - host: Box) -> String { +fn execute_with( + executable: ExecutableType, + variables: PassableMap, + platform: Option>>, + host: Arc, +) -> String { + let host = host.clone(); let host = Arc::new(Mutex::new(host)); let mut ctx = Context::default(); // Add predefined variables locally to the context - variables.map.iter().for_each(|it| { - ctx.add_variable(it.0.as_str(), it.1.to_cel()).unwrap() - }); + variables + .map + .iter() + .for_each(|it| ctx.add_variable(it.0.as_str(), it.1.to_cel()).unwrap()); // Add maybe function ctx.add_function("maybe", maybe); // This function is used to extract the value of a property from the host context // As UniFFi doesn't support recursive enums yet, we have to pass it in as a // JSON serialized string of a PassableValue from Host and deserialize it here - fn prop_for(name: Arc, - args: Option>, ctx: &Box) -> Option { + fn prop_for( + name: Arc, + args: Option>, + ctx: &Arc, + ) -> Option { // Get computed property - let val = ctx.computed_property(name.clone().to_string(), serde_json::to_string(&args).unwrap()); + let val = smol::block_on(async move { + let ctx = ctx.clone(); + + ctx.computed_property( + name.clone().to_string(), + serde_json::to_string(&args).unwrap(), + ) + .await + }); // Deserialize the value - let passable: Option = serde_json::from_str(val.as_str()) - .unwrap_or(None); + let passable: Option = serde_json::from_str(val.as_str()).unwrap_or(None); passable } - let platform = platform.unwrap().clone(); + let platform = platform.unwrap_or(HashMap::new()).clone(); // Create platform properties as a map of keys and function names - let platform_properties: HashMap = platform.iter().map(|it| { - let args = it.1.clone(); - let args = if args.is_empty() { None } else { Some(Box::new(PassableValue::List(args))) }; - let name = it.0.clone(); - (Key::String(Arc::new(name.clone())), Function(name, args).to_cel()) - }).collect(); + let platform_properties: HashMap = platform + .iter() + .map(|it| { + let args = it.1.clone(); + let args = if args.is_empty() { + None + } else { + Some(Box::new(PassableValue::List(args))) + }; + let name = it.0.clone(); + ( + Key::String(Arc::new(name.clone())), + Function(name, args).to_cel(), + ) + }) + .collect(); // Add the map to the platform object - ctx.add_variable("platform", Value::Map(Map { map: Arc::new(platform_properties) })).unwrap(); + ctx.add_variable( + "platform", + Value::Map(Map { + map: Arc::new(platform_properties), + }), + ) + .unwrap(); // Add those functions to the context for it in platform.iter() { let key = it.0.clone(); let host_clone = Arc::clone(&host); // Clone the Arc to pass into the closure let key_str = key.clone(); // Clone key for usage in the closure - ctx.add_function(key_str.as_str(), move |ftx: &FunctionContext| -> Result { - let fx = ftx.clone(); - let name = fx.name.clone(); // Move the name into the closure - let args = fx.args.clone(); // Clone the arguments - let host = host_clone.lock().unwrap(); // Lock the host for safe access - prop_for(name.clone(), Some(args.iter().map(|expression| - DisplayableValue(ftx.ptx.resolve(expression).unwrap()).to_passable()).collect()), &*host) - .map_or(Err(ExecutionError::UndeclaredReference(name)), |v| Ok(v.to_cel())) - }); + ctx.add_function( + key_str.as_str(), + move |ftx: &FunctionContext| -> Result { + let fx = ftx.clone(); + let name = fx.name.clone(); // Move the name into the closure + let args = fx.args.clone(); // Clone the arguments + let host = host_clone.lock().unwrap(); // Lock the host for safe access + prop_for( + name.clone(), + Some( + args.iter() + .map(|expression| { + DisplayableValue(ftx.ptx.resolve(expression).unwrap()).to_passable() + }) + .collect(), + ), + &*host, + ) + .map_or(Err(ExecutionError::UndeclaredReference(name)), |v| { + Ok(v.to_cel()) + }) + }, + ); } - let val = match executable { - ExecutableType::AST(ast) => &ctx.resolve(&ast), - ExecutableType::CompiledProgram(program) => &program.execute(&ctx) + AST(ast) => &ctx.resolve(&ast), + CompiledProgram(program) => &program.execute(&ctx), }; match val { @@ -166,11 +223,10 @@ pub fn maybe( return ftx.ptx.resolve(&left).or_else(|_| ftx.ptx.resolve(&right)); } +// Wrappers around CEL values used so that we can create extensions on them +pub struct DisplayableValue(Value); -// Wrappers around CEL values so we can create extensions on them -pub struct DisplayableValue(cel_interpreter::Value); - -pub struct DisplayableError(cel_interpreter::ExecutionError); +pub struct DisplayableError(ExecutionError); impl fmt::Display for DisplayableValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -181,25 +237,38 @@ impl fmt::Display for DisplayableValue { Value::String(s) => write!(f, "{}", s), // Add more variants as needed Value::UInt(i) => write!(f, "{}", i), - Value::Bytes(_) => { write!(f, "{}", "bytes go here") } + Value::Bytes(_) => { + write!(f, "{}", "bytes go here") + } Value::Bool(b) => write!(f, "{}", b), Value::Duration(d) => write!(f, "{}", d), Value::Timestamp(t) => write!(f, "{}", t), Value::Null => write!(f, "{}", "null"), Value::Function(name, _) => write!(f, "{}", name), Value::Map(map) => { - let res: HashMap = map.map.iter().map(|(k, v)| { - let key = DisplayableValue(k.try_into_value().unwrap().clone()).to_string(); - let value = DisplayableValue(v.clone()).to_string().replace("\\", ""); - (key, value) - }).collect(); + let res: HashMap = map + .map + .iter() + .map(|(k, v)| { + let key = DisplayableValue(k.try_into_value().unwrap().clone()).to_string(); + let value = DisplayableValue(v.clone()).to_string().replace("\\", ""); + (key, value) + }) + .collect(); let map = serde_json::to_string(&res).unwrap(); write!(f, "{}", map) } - Value::List(list) => write!(f, "{}", list.iter().map(|v| { - let key = DisplayableValue(v.clone()); - return key.to_string(); - }).collect::>().join(",\n ")), + Value::List(list) => write!( + f, + "{}", + list.iter() + .map(|v| { + let key = DisplayableValue(v.clone()); + return key.to_string(); + }) + .collect::>() + .join(",\n ") + ), } } } @@ -218,18 +287,23 @@ impl fmt::Display for DisplayableError { ExecutionError::MissingArgumentOrTarget => write!(f, "MissingArgumentOrTarget"), ExecutionError::ValuesNotComparable(_, _) => write!(f, "ValuesNotComparable"), ExecutionError::UnsupportedUnaryOperator(_, _) => write!(f, "UnsupportedUnaryOperator"), - ExecutionError::UnsupportedBinaryOperator(_, _, _) => write!(f, "UnsupportedBinaryOperator"), + ExecutionError::UnsupportedBinaryOperator(_, _, _) => { + write!(f, "UnsupportedBinaryOperator") + } ExecutionError::UnsupportedMapIndex(_) => write!(f, "UnsupportedMapIndex"), ExecutionError::UnsupportedListIndex(_) => write!(f, "UnsupportedListIndex"), ExecutionError::UnsupportedIndex(_, _) => write!(f, "UnsupportedIndex"), - ExecutionError::UnsupportedFunctionCallIdentifierType(_) => write!(f, "UnsupportedFunctionCallIdentifierType"), - ExecutionError::UnsupportedFieldsConstruction(_) => write!(f, "UnsupportedFieldsConstruction"), + ExecutionError::UnsupportedFunctionCallIdentifierType(_) => { + write!(f, "UnsupportedFunctionCallIdentifierType") + } + ExecutionError::UnsupportedFieldsConstruction(_) => { + write!(f, "UnsupportedFieldsConstruction") + } ExecutionError::FunctionError { .. } => write!(f, "FunctionError"), } } } - #[cfg(test)] mod tests { use super::*; @@ -238,18 +312,20 @@ mod tests { map: HashMap, } + #[async_trait] impl HostContext for TestContext { - fn computed_property(&self, name: String, args: String) -> String { + async fn computed_property(&self, name: String, args: String) -> String { self.map.get(&name).unwrap().to_string() } } - #[test] - fn test_variables() { - let ctx = Box::new(TestContext { - map: HashMap::new() + #[tokio::test] + async fn test_variables() { + let ctx = Arc::new(TestContext { + map: HashMap::new(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map" : { @@ -258,16 +334,20 @@ mod tests { "expression": "foo == 100" } - "#.to_string(), ctx); + "# + .to_string(), + ctx, + ); assert_eq!(res, "true"); } - #[test] - fn test_execution_with_ctx() { - let ctx = Box::new(TestContext { - map: HashMap::new() + #[tokio::test] + async fn test_execution_with_ctx() { + let ctx = Arc::new(TestContext { + map: HashMap::new(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map" : { @@ -277,17 +357,21 @@ mod tests { "expression": "foo + bar == 142" } - "#.to_string(), ctx); + "# + .to_string(), + ctx, + ); assert_eq!(res, "true"); } #[test] - fn test_custom_function_with_arg() { - let ctx = Box::new(TestContext { - map: HashMap::new() + fn test_unknown_function_with_arg_fails_with_undeclared_ref() { + let ctx = Arc::new(TestContext { + map: HashMap::new(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map" : { @@ -296,16 +380,20 @@ mod tests { "expression": "test_custom_func(foo) == 101" } - "#.to_string(), ctx); - assert_eq!(res, "true"); + "# + .to_string(), + ctx, + ); + assert_eq!(res, "UndeclaredReference"); } #[test] fn test_list_contains() { - let ctx = Box::new(TestContext { - map: HashMap::new() + let ctx = Arc::new(TestContext { + map: HashMap::new(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map" : { @@ -322,16 +410,20 @@ mod tests { "expression": "numbers.contains(2)" } - "#.to_string(), ctx); + "# + .to_string(), + ctx, + ); assert_eq!(res, "true"); } - #[test] - fn test_execution_with_map() { - let ctx = Box::new(TestContext { - map: HashMap::new() + #[tokio::test] + async fn test_execution_with_map() { + let ctx = Arc::new(TestContext { + map: HashMap::new(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map": { @@ -353,19 +445,26 @@ mod tests { "expression": "user.should_display == true && user.some_value > 12" } - "#.to_string(), ctx); + "# + .to_string(), + ctx, + ); println!("{}", res); assert_eq!(res, "true"); } - #[test] - fn test_execution_with_platform_reference() { + #[tokio::test] + async fn test_execution_with_platform_reference() { let days_since = PassableValue::UInt(7); let days_since = serde_json::to_string(&days_since).unwrap(); - let ctx = Box::new(TestContext { - map: [("daysSinceEvent".to_string(), days_since)].iter().cloned().collect() + let ctx = Arc::new(TestContext { + map: [("daysSinceEvent".to_string(), days_since)] + .iter() + .cloned() + .collect(), }); - let res = evaluate_with_context(r#" + let res = evaluate_with_context( + r#" { "variables": { "map": { @@ -392,9 +491,11 @@ mod tests { }, "expression": "platform.daysSinceEvent(\"test\") == user.some_value" } - "#.to_string(), ctx); + "# + .to_string(), + ctx, + ); println!("{}", res); assert_eq!(res, "true"); } - } diff --git a/src/models.md b/src/models.md new file mode 100644 index 0000000..c1708eb --- /dev/null +++ b/src/models.md @@ -0,0 +1,32 @@ +## Models + +An example of `ExecutionContext` JSON for convenience: + +```json + { + "variables": { + "map": { + "user": { + "type": "map", + "value": { + "should_display": { + "type": "bool", + "value": true + }, + "some_value": { + "type": "uint", + "value": 7 + } + } + } + } + }, + "platform" : { + "daysSinceEvent": [{ + "type": "string", + "value": "event_name" + }] + }, + "expression": "platform.daysSinceEvent(\"test\") == user.some_value" + } +``` \ No newline at end of file diff --git a/src/models.rs b/src/models.rs index 962d107..bebb923 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,9 +1,9 @@ -use std::collections::HashMap; -use std::sync::Arc; +use crate::DisplayableValue; use cel_interpreter::objects::{Key, Map}; use cel_interpreter::Value; use serde::{Deserialize, Serialize}; -use crate::{DisplayableValue}; +use std::collections::HashMap; +use std::sync::Arc; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub(crate) struct ExecutionContext { @@ -12,9 +12,7 @@ pub(crate) struct ExecutionContext { pub(crate) platform: Option>>, } - - -#[derive(Serialize, Deserialize,Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct PassableMap { pub map: HashMap, } @@ -25,7 +23,7 @@ pub enum PassableValue { #[serde(rename = "list")] List(Vec), #[serde(rename = "map")] - Map(HashMap), + Map(HashMap), #[serde(rename = "function")] Function(String, Option>), #[serde(rename = "int")] @@ -50,7 +48,9 @@ impl PartialEq for PassableValue { match (self, other) { (PassableValue::Map(a), PassableValue::Map(b)) => a == b, (PassableValue::List(a), PassableValue::List(b)) => a == b, - (PassableValue::Function(a1, a2), PassableValue::Function(b1, b2)) => a1 == b1 && a2 == b2, + (PassableValue::Function(a1, a2), PassableValue::Function(b1, b2)) => { + a1 == b1 && a2 == b2 + } (PassableValue::Int(a), PassableValue::Int(b)) => a == b, (PassableValue::UInt(a), PassableValue::UInt(b)) => a == b, (PassableValue::Float(a), PassableValue::Float(b)) => a == b, @@ -79,23 +79,26 @@ impl PartialEq for PassableValue { } } - - impl PassableValue { pub fn to_cel(&self) -> Value { match self { PassableValue::List(list) => { let mapped_list: Vec = list.iter().map(|item| item.to_cel()).collect(); Value::List(Arc::new(mapped_list)) - }, + } PassableValue::Map(map) => { - let mapped_map = map.iter().map(|(k, v)| (Key::String(Arc::from(k.clone())), (*v).to_cel())).collect(); - Value::Map(Map { map: Arc::new(mapped_map) }) - }, + let mapped_map = map + .iter() + .map(|(k, v)| (Key::String(Arc::from(k.clone())), (*v).to_cel())) + .collect(); + Value::Map(Map { + map: Arc::new(mapped_map), + }) + } PassableValue::Function(name, arg) => { let mapped_arg = arg.as_ref().map(|arg| arg.to_cel()); Value::Function(Arc::from(name.clone()), mapped_arg.map(|v| Box::new(v))) - }, + } PassableValue::Int(i) => Value::Int(*i), PassableValue::UInt(u) => Value::UInt(*u), PassableValue::Float(f) => Value::Float(*f), @@ -113,31 +116,40 @@ fn key_to_string(key: Key) -> String { Key::String(s) => (*s).clone(), Key::Int(i) => i.to_string(), Key::Uint(u) => u.to_string(), - Key::Bool(b) => {b.to_string()} + Key::Bool(b) => b.to_string(), } - } impl DisplayableValue { pub fn to_passable(&self) -> PassableValue { match &self.0 { Value::List(list) => { - let mapped_list: Vec = list.iter().map(|item| - DisplayableValue(item.clone()).to_passable()).collect(); + let mapped_list: Vec = list + .iter() + .map(|item| DisplayableValue(item.clone()).to_passable()) + .collect(); PassableValue::List(mapped_list) - }, + } Value::Map(map) => { - let mapped_map: HashMap = map.map.iter().map(|(k, v)| (key_to_string(k.clone()), - DisplayableValue(v.clone()).to_passable())).collect(); + let mapped_map: HashMap = map + .map + .iter() + .map(|(k, v)| { + ( + key_to_string(k.clone()), + DisplayableValue(v.clone()).to_passable(), + ) + }) + .collect(); PassableValue::Map(mapped_map) - }, + } Value::Function(name, arg) => { let mapped_arg = arg.as_ref().map(|arg| { let arg = *arg.clone(); let arg = DisplayableValue(arg).to_passable(); Box::new(arg) }); - PassableValue::Function( (**name).clone(), mapped_arg) - }, + PassableValue::Function((**name).clone(), mapped_arg) + } Value::Int(i) => PassableValue::Int(*i), Value::UInt(u) => PassableValue::UInt(*u), Value::Float(f) => PassableValue::Float(*f),