diff --git a/partiql-eval/src/eval/eval_expr_wrapper.rs b/partiql-eval/src/eval/eval_expr_wrapper.rs index b88d3cc6..16c16052 100644 --- a/partiql-eval/src/eval/eval_expr_wrapper.rs +++ b/partiql-eval/src/eval/eval_expr_wrapper.rs @@ -101,6 +101,11 @@ pub(crate) trait ArgChecker: Debug { typ: &PartiqlShape, arg: Cow<'a, Value>, ) -> ArgCheckControlFlow>; + + /// Validate all arguments. + fn validate_args(args: &[Cow<'_, Value>]) -> Result<(), Value> { + Ok(()) + } } /// How to handle argument mismatch and `MISSING` propagation @@ -273,7 +278,12 @@ impl, ArgC: ArgChecker ControlFlow::Break(Missing) }; - match evaluate_args::<{ STRICT }, ArgC, _>(&self.args, |n| &self.types[n], bindings, ctx) { + match evaluate_and_validate_args::<{ STRICT }, ArgC, _>( + &self.args, + |n| &self.types[n], + bindings, + ctx, + ) { ControlFlow::Continue(result) => match result.try_into() { Ok(a) => ControlFlow::Continue(a), Err(args) => err_arg_count_mismatch(args), @@ -283,7 +293,7 @@ impl, ArgC: ArgChecker } } -pub(crate) fn evaluate_args< +pub(crate) fn evaluate_and_validate_args< 'a, 'c, 't, @@ -352,7 +362,19 @@ where ControlFlow::Break(v) } else { // If `propagate` is `None`, then return result - ControlFlow::Continue(result) + + match ArgC::validate_args(&result) { + Ok(_) => ControlFlow::Continue(result), + Err(value) => { + if STRICT { + // TODO better error messages + ctx.add_error(EvaluationError::IllegalState( + "Arguments failed validation".to_string(), + )) + } + ControlFlow::Break(value) + } + } } } diff --git a/partiql-eval/src/eval/expr/functions.rs b/partiql-eval/src/eval/expr/functions.rs index 1d6e77e4..485ddac0 100644 --- a/partiql-eval/src/eval/expr/functions.rs +++ b/partiql-eval/src/eval/expr/functions.rs @@ -1,4 +1,6 @@ -use crate::eval::eval_expr_wrapper::{evaluate_args, DefaultArgChecker, PropagateMissing}; +use crate::eval::eval_expr_wrapper::{ + evaluate_and_validate_args, DefaultArgChecker, PropagateMissing, +}; use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use crate::eval::EvalContext; @@ -41,7 +43,12 @@ impl EvalExpr for EvalExprFnScalar { { type Check = DefaultArgChecker>; let typ = PartiqlShapeBuilder::init_or_get().new_struct(StructType::new_any()); - match evaluate_args::<{ STRICT }, Check, _>(&self.args, |_| &typ, bindings, ctx) { + match evaluate_and_validate_args::<{ STRICT }, Check, _>( + &self.args, + |_| &typ, + bindings, + ctx, + ) { ControlFlow::Break(v) => Cow::Owned(v), ControlFlow::Continue(args) => match self.plan.evaluate(&args, ctx.as_session()) { Ok(v) => v, diff --git a/partiql-eval/src/eval/expr/operators.rs b/partiql-eval/src/eval/expr/operators.rs index 693bf695..ed975dc8 100644 --- a/partiql-eval/src/eval/expr/operators.rs +++ b/partiql-eval/src/eval/expr/operators.rs @@ -12,7 +12,7 @@ use partiql_types::{ Static, StructType, }; use partiql_value::Value::{Boolean, Missing, Null}; -use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; +use partiql_value::{BinaryAnd, Comparable, EqualityValue, NullableEq, NullableOrd, Tuple, Value}; use std::borrow::{Borrow, Cow}; use std::fmt::{Debug, Formatter}; @@ -146,6 +146,31 @@ impl ArgChecker } } +#[derive(Debug)] +pub(crate) struct ComparisonArgChecker { + check: PhantomData>, +} + +impl ArgChecker + for ComparisonArgChecker +{ + #[inline] + fn arg_check<'a>( + typ: &PartiqlShape, + arg: Cow<'a, Value>, + ) -> ArgCheckControlFlow> { + DefaultArgChecker::<{ STRICT }, OnMissing>::arg_check(typ, arg) + } + + fn validate_args(args: &[Cow<'_, Value>]) -> Result<(), Value> { + if args.len() == 2 && args[0].is_comparable_to(&args[1]) { + Ok(()) + } else { + Err(OnMissing::propagate()) + } + } +} + impl BindEvalExpr for EvalOpBinary { #[inline] fn bind( @@ -157,6 +182,7 @@ impl BindEvalExpr for EvalOpBinary { type InCheck = DefaultArgChecker>; type Check = DefaultArgChecker>; type EqCheck = DefaultArgChecker>; + type CompCheck = ComparisonArgChecker>; type MathCheck = DefaultArgChecker>; macro_rules! create { @@ -177,6 +203,12 @@ impl BindEvalExpr for EvalOpBinary { }; } + macro_rules! comparison { + ($f:expr) => { + create!(CompCheck, [type_dynamic!(), type_dynamic!()], $f) + }; + } + macro_rules! math { ($f:expr) => {{ let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()); @@ -195,10 +227,10 @@ impl BindEvalExpr for EvalOpBinary { let wrap = EqualityValue::; NullableEq::neq(&wrap(lhs), &wrap(rhs)) }), - EvalOpBinary::Gt => equality!(NullableOrd::gt), - EvalOpBinary::Gteq => equality!(NullableOrd::gteq), - EvalOpBinary::Lt => equality!(NullableOrd::lt), - EvalOpBinary::Lteq => equality!(NullableOrd::lteq), + EvalOpBinary::Gt => comparison!(NullableOrd::gt), + EvalOpBinary::Gteq => comparison!(NullableOrd::gteq), + EvalOpBinary::Lt => comparison!(NullableOrd::lt), + EvalOpBinary::Lteq => comparison!(NullableOrd::lteq), EvalOpBinary::Add => math!(|lhs, rhs| lhs + rhs), EvalOpBinary::Sub => math!(|lhs, rhs| lhs - rhs), EvalOpBinary::Mul => math!(|lhs, rhs| lhs * rhs), diff --git a/partiql/tests/common.rs b/partiql/tests/common.rs new file mode 100644 index 00000000..c8f132a6 --- /dev/null +++ b/partiql/tests/common.rs @@ -0,0 +1,115 @@ +use partiql_ast_passes::error::AstTransformationError; +use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; +use partiql_catalog::context::SystemContext; +use partiql_eval as eval; +use partiql_eval::env::basic::MapBindings; +use partiql_eval::error::{EvalErr, PlanErr}; +use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated}; +use partiql_eval::plan::EvaluationMode; +use partiql_logical as logical; +use partiql_parser::{Parsed, ParserError, ParserResult}; +use partiql_value::{DateTime, Value}; +use std::error::Error; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TestError<'a> { + #[error("Parse error: {0:?}")] + Parse(ParserError<'a>), + #[error("Lower error: {0:?}")] + Lower(AstTransformationError), + #[error("Plan error: {0:?}")] + Plan(PlanErr), + #[error("Evaluation error: {0:?}")] + Eval(EvalErr), + #[error("Other: {0:?}")] + Other(Box), +} + +impl<'a> From> for TestError<'a> { + fn from(err: ParserError<'a>) -> Self { + TestError::Parse(err) + } +} + +impl From for TestError<'_> { + fn from(err: AstTransformationError) -> Self { + TestError::Lower(err) + } +} + +impl From for TestError<'_> { + fn from(err: PlanErr) -> Self { + TestError::Plan(err) + } +} + +impl From for TestError<'_> { + fn from(err: EvalErr) -> Self { + TestError::Eval(err) + } +} + +impl From> for TestError<'_> { + fn from(err: Box) -> Self { + TestError::Other(err) + } +} + +#[track_caller] +#[inline] +pub fn parse(statement: &str) -> ParserResult<'_> { + partiql_parser::Parser::default().parse(statement) +} + +#[track_caller] +#[inline] +pub fn lower( + catalog: &dyn Catalog, + parsed: &Parsed<'_>, +) -> Result, AstTransformationError> { + let planner = partiql_logical_planner::LogicalPlanner::new(catalog); + planner.lower(parsed) +} + +#[track_caller] +#[inline] +pub fn compile( + mode: EvaluationMode, + catalog: &dyn Catalog, + logical: logical::LogicalPlan, +) -> Result { + let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog); + planner.compile(&logical) +} + +#[track_caller] +#[inline] +pub fn evaluate(mut plan: EvalPlan, bindings: MapBindings) -> EvalResult { + let sys = SystemContext { + now: DateTime::from_system_now_utc(), + }; + let ctx = BasicContext::new(bindings, sys); + plan.execute_mut(&ctx) +} + +#[track_caller] +#[inline] +pub fn eval_query_with_catalog<'a, 'b>( + statement: &'a str, + catalog: &'b dyn Catalog, + mode: EvaluationMode, +) -> Result> { + let parsed = parse(statement)?; + let lowered = lower(catalog, &parsed)?; + let bindings = Default::default(); + let plan = compile(mode, catalog, lowered)?; + Ok(evaluate(plan, bindings)?) +} + +#[track_caller] +#[inline] +pub fn eval_query(statement: &str, mode: EvaluationMode) -> Result> { + let catalog = PartiqlCatalog::default(); + eval_query_with_catalog(statement, &catalog, mode) +} diff --git a/partiql/tests/comparisons.rs b/partiql/tests/comparisons.rs new file mode 100644 index 00000000..416c6775 --- /dev/null +++ b/partiql/tests/comparisons.rs @@ -0,0 +1,54 @@ +use crate::common::{ + compile, eval_query, eval_query_with_catalog, evaluate, lower, parse, TestError, +}; +use assert_matches::assert_matches; +use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; +use partiql_catalog::extension::Extension; +use partiql_eval::eval::Evaluated; +use partiql_eval::plan::EvaluationMode; +use partiql_extension_value_functions::PartiqlValueFnExtension; +use partiql_value::Value; +use std::os::macos::raw::stat; + +mod common; + +#[track_caller] +#[inline] +pub fn eval<'a>(statement: &'a str) { + dbg!(&statement); + let res = eval_query(statement, EvaluationMode::Permissive); + assert_matches!(res, Ok(_)); + let res = res.unwrap().result; + assert_matches!(res, Value::Missing); + + let res = eval_query(statement, EvaluationMode::Strict); + assert_matches!(res, Err(_)); + let err = res.unwrap_err(); + assert_matches!(err, TestError::Eval(_)); +} + +#[track_caller] +#[inline] +pub fn eval_op<'a>(op: &'a str) { + eval(&format!("1 {op} 'foo'")) +} + +#[test] +fn lt() { + eval_op("<") +} + +#[test] +fn gt() { + eval_op(">") +} + +#[test] +fn lte() { + eval_op("<=") +} + +#[test] +fn gte() { + eval_op(">=") +} diff --git a/partiql/tests/extension_error.rs b/partiql/tests/extension_error.rs index 3b71f856..abbc7624 100644 --- a/partiql/tests/extension_error.rs +++ b/partiql/tests/extension_error.rs @@ -19,8 +19,10 @@ use partiql_eval::plan::EvaluationMode; use partiql_parser::{Parsed, ParserResult}; use partiql_value::{bag, tuple, DateTime, Value}; +use crate::common::{lower, parse, TestError}; use partiql_logical as logical; +mod common; #[derive(Debug)] pub struct UserCtxTestExtension {} @@ -115,21 +117,6 @@ impl Iterator for TestDataGen { Some(Err(Box::new(UserCtxError::Runtime))) } } -#[track_caller] -#[inline] -pub(crate) fn parse(statement: &str) -> ParserResult { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -pub(crate) fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> partiql_logical::LogicalPlan { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed).expect("lower") -} #[track_caller] #[inline] @@ -156,7 +143,7 @@ pub(crate) fn evaluate( } #[test] -fn test_context_bad_args_permissive() { +fn test_context_bad_args_permissive() -> Result<(), TestError<'static>> { let query = "SELECT foo, bar from test_user_context(9) as data"; let mut catalog = PartiqlCatalog::default(); @@ -164,7 +151,7 @@ fn test_context_bad_args_permissive() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -178,9 +165,11 @@ fn test_context_bad_args_permissive() { assert!(out.is_ok()); assert_eq!(out.unwrap().result, bag!(tuple!()).into()); + + Ok(()) } #[test] -fn test_context_bad_args_strict() { +fn test_context_bad_args_strict() -> Result<(), TestError<'static>> { use assert_matches::assert_matches; let query = "SELECT foo, bar from test_user_context(9) as data"; @@ -189,7 +178,7 @@ fn test_context_bad_args_strict() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -202,10 +191,12 @@ fn test_context_bad_args_strict() { assert_matches!(err, EvaluationError::ExtensionResultError(err) => { assert_eq!(err.to_string(), "bad arguments") }); + + Ok(()) } #[test] -fn test_context_runtime_permissive() { +fn test_context_runtime_permissive() -> Result<(), TestError<'static>> { let query = "SELECT foo, bar from test_user_context('counter') as data"; let mut catalog = PartiqlCatalog::default(); @@ -213,7 +204,7 @@ fn test_context_runtime_permissive() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -227,10 +218,11 @@ fn test_context_runtime_permissive() { assert!(out.is_ok()); assert_eq!(out.unwrap().result, bag!(tuple!()).into()); + Ok(()) } #[test] -fn test_context_runtime_strict() { +fn test_context_runtime_strict() -> Result<(), TestError<'static>> { use assert_matches::assert_matches; let query = "SELECT foo, bar from test_user_context('counter') as data"; @@ -239,7 +231,7 @@ fn test_context_runtime_strict() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let ctx: [(String, &dyn Any); 0] = []; @@ -252,4 +244,6 @@ fn test_context_runtime_strict() { assert_matches!(err, EvaluationError::ExtensionResultError(err) => { assert_eq!(err.to_string(), "runtime error") }); + + Ok(()) } diff --git a/partiql/tests/pretty.rs b/partiql/tests/pretty.rs index cbdc045f..be260c9a 100644 --- a/partiql/tests/pretty.rs +++ b/partiql/tests/pretty.rs @@ -1,18 +1,15 @@ +use crate::common::parse; use itertools::Itertools; use partiql_ast::ast::{AstNode, TopLevelQuery}; use partiql_ast::pretty::ToPretty; use partiql_parser::ParserResult; -#[track_caller] -#[inline] -fn parse(statement: &str) -> ParserResult<'_> { - partiql_parser::Parser::default().parse(statement) -} +mod common; #[track_caller] #[inline] fn pretty_print_test(name: &str, statement: &str) { - let res = parse(statement); + let res = common::parse(statement); assert!(res.is_ok()); let res = res.unwrap(); diff --git a/partiql/tests/tuple_ops.rs b/partiql/tests/tuple_ops.rs index 37112401..db1efbf1 100644 --- a/partiql/tests/tuple_ops.rs +++ b/partiql/tests/tuple_ops.rs @@ -1,113 +1,22 @@ +use crate::common::{compile, eval_query_with_catalog, evaluate, lower, parse, TestError}; use assert_matches::assert_matches; -use partiql_ast_passes::error::AstTransformationError; -use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; -use partiql_catalog::context::SystemContext; +use partiql_catalog::catalog::PartiqlCatalog; use partiql_catalog::extension::Extension; -use partiql_eval as eval; -use partiql_eval::env::basic::MapBindings; -use partiql_eval::error::{EvalErr, PlanErr}; -use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated}; +use partiql_eval::eval::Evaluated; use partiql_eval::plan::EvaluationMode; use partiql_extension_value_functions::PartiqlValueFnExtension; -use partiql_logical as logical; -use partiql_parser::{Parsed, ParserError, ParserResult}; -use partiql_value::{DateTime, Value}; -use std::error::Error; -use thiserror::Error; +use partiql_value::Value; -#[derive(Error, Debug)] -enum TestError<'a> { - #[error("Parse error: {0:?}")] - Parse(ParserError<'a>), - #[error("Lower error: {0:?}")] - Lower(AstTransformationError), - #[error("Plan error: {0:?}")] - Plan(PlanErr), - #[error("Evaluation error: {0:?}")] - Eval(EvalErr), - #[error("Other: {0:?}")] - Other(Box), -} - -impl<'a> From> for TestError<'a> { - fn from(err: ParserError<'a>) -> Self { - TestError::Parse(err) - } -} - -impl From for TestError<'_> { - fn from(err: AstTransformationError) -> Self { - TestError::Lower(err) - } -} - -impl From for TestError<'_> { - fn from(err: PlanErr) -> Self { - TestError::Plan(err) - } -} - -impl From for TestError<'_> { - fn from(err: EvalErr) -> Self { - TestError::Eval(err) - } -} - -impl From> for TestError<'_> { - fn from(err: Box) -> Self { - TestError::Other(err) - } -} - -#[track_caller] -#[inline] -fn parse(statement: &str) -> ParserResult<'_> { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> Result, AstTransformationError> { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed) -} - -#[track_caller] -#[inline] -fn compile( - mode: EvaluationMode, - catalog: &dyn Catalog, - logical: logical::LogicalPlan, -) -> Result { - let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog); - planner.compile(&logical) -} - -#[track_caller] -#[inline] -fn evaluate(mut plan: EvalPlan, bindings: MapBindings) -> EvalResult { - let sys = SystemContext { - now: DateTime::from_system_now_utc(), - }; - let ctx = BasicContext::new(bindings, sys); - plan.execute_mut(&ctx) -} +mod common; #[track_caller] #[inline] -fn eval(statement: &str, mode: EvaluationMode) -> Result> { +pub fn eval(statement: &str, mode: EvaluationMode) -> Result> { let mut catalog = PartiqlCatalog::default(); let ext = PartiqlValueFnExtension::default(); ext.load(&mut catalog)?; - let parsed = parse(statement)?; - let lowered = lower(&catalog, &parsed)?; - let bindings = Default::default(); - let plan = compile(mode, &catalog, lowered)?; - Ok(evaluate(plan, bindings)?) + eval_query_with_catalog(statement, &catalog, mode) } #[test] diff --git a/partiql/tests/user_context.rs b/partiql/tests/user_context.rs index 079acfb4..b53e871c 100644 --- a/partiql/tests/user_context.rs +++ b/partiql/tests/user_context.rs @@ -19,8 +19,10 @@ use partiql_eval::plan::EvaluationMode; use partiql_parser::{Parsed, ParserResult}; use partiql_value::{bag, tuple, DateTime, Value}; +use crate::common::{lower, parse, TestError}; use partiql_logical as logical; +mod common; #[derive(Debug)] pub struct UserCtxTestExtension {} @@ -141,22 +143,6 @@ pub struct Counter { data: RefCell, } -#[track_caller] -#[inline] -pub(crate) fn parse(statement: &str) -> ParserResult { - partiql_parser::Parser::default().parse(statement) -} - -#[track_caller] -#[inline] -pub(crate) fn lower( - catalog: &dyn Catalog, - parsed: &Parsed<'_>, -) -> partiql_logical::LogicalPlan { - let planner = partiql_logical_planner::LogicalPlanner::new(catalog); - planner.lower(parsed).expect("lower") -} - #[track_caller] #[inline] pub(crate) fn evaluate( @@ -183,8 +169,9 @@ pub(crate) fn evaluate( Value::Missing } } + #[test] -fn test_context() { +fn test_context() -> Result<(), TestError<'static>> { let expected: Value = bag![ tuple![("foo", 1), ("bar", "id_1")], tuple![("foo", 0), ("bar", "id_2")], @@ -201,7 +188,7 @@ fn test_context() { ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); - let lowered = lower(&catalog, &parsed.expect("parse")); + let lowered = lower(&catalog, &parsed.expect("parse"))?; let bindings = Default::default(); let counter = Counter { @@ -213,4 +200,6 @@ fn test_context() { assert!(out.is_bag()); assert_eq!(&out, &expected); assert_eq!(*counter.data.borrow(), 0); + + Ok(()) }