Skip to content

Commit

Permalink
Add validation for comparison operations
Browse files Browse the repository at this point in the history
  • Loading branch information
jpschorr committed Oct 9, 2024
1 parent ee329ce commit e2d1f49
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 155 deletions.
28 changes: 25 additions & 3 deletions partiql-eval/src/eval/eval_expr_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ pub(crate) trait ArgChecker: Debug {
typ: &PartiqlShape,
arg: Cow<'a, Value>,
) -> ArgCheckControlFlow<Value, Cow<'a, Value>>;

/// Validate all arguments.
fn validate_args(args: &[Cow<'_, Value>]) -> Result<(), Value> {
Ok(())
}
}

/// How to handle argument mismatch and `MISSING` propagation
Expand Down Expand Up @@ -273,7 +278,12 @@ impl<const STRICT: bool, const N: usize, E: ExecuteEvalExpr<N>, ArgC: ArgChecker
ControlFlow::Break(Missing)
};

match evaluate_args::<{ STRICT }, ArgC, _>(&self.args, |n| &self.types[n], bindings, ctx) {
match evaluate_and_validate_args::<{ STRICT }, ArgC, _>(
&self.args,
|n| &self.types[n],
bindings,
ctx,
) {
ControlFlow::Continue(result) => match result.try_into() {
Ok(a) => ControlFlow::Continue(a),
Err(args) => err_arg_count_mismatch(args),
Expand All @@ -283,7 +293,7 @@ impl<const STRICT: bool, const N: usize, E: ExecuteEvalExpr<N>, ArgC: ArgChecker
}
}

pub(crate) fn evaluate_args<
pub(crate) fn evaluate_and_validate_args<
'a,
'c,
't,
Expand Down Expand Up @@ -352,7 +362,19 @@ where
ControlFlow::Break(v)
} else {
// If `propagate` is `None`, then return result
ControlFlow::Continue(result)

match ArgC::validate_args(&result) {
Ok(_) => ControlFlow::Continue(result),
Err(value) => {
if STRICT {
// TODO better error messages
ctx.add_error(EvaluationError::IllegalState(
"Arguments failed validation".to_string(),
))
}
ControlFlow::Break(value)
}
}
}
}

Expand Down
11 changes: 9 additions & 2 deletions partiql-eval/src/eval/expr/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::eval::eval_expr_wrapper::{evaluate_args, DefaultArgChecker, PropagateMissing};
use crate::eval::eval_expr_wrapper::{
evaluate_and_validate_args, DefaultArgChecker, PropagateMissing,
};

use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};
use crate::eval::EvalContext;
Expand Down Expand Up @@ -41,7 +43,12 @@ impl<const STRICT: bool> EvalExpr for EvalExprFnScalar<STRICT> {
{
type Check<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;
let typ = PartiqlShapeBuilder::init_or_get().new_struct(StructType::new_any());
match evaluate_args::<{ STRICT }, Check<STRICT>, _>(&self.args, |_| &typ, bindings, ctx) {
match evaluate_and_validate_args::<{ STRICT }, Check<STRICT>, _>(
&self.args,
|_| &typ,
bindings,
ctx,
) {
ControlFlow::Break(v) => Cow::Owned(v),
ControlFlow::Continue(args) => match self.plan.evaluate(&args, ctx.as_session()) {
Ok(v) => v,
Expand Down
42 changes: 37 additions & 5 deletions partiql-eval/src/eval/expr/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use partiql_types::{
Static, StructType,
};
use partiql_value::Value::{Boolean, Missing, Null};
use partiql_value::{BinaryAnd, EqualityValue, NullableEq, NullableOrd, Tuple, Value};
use partiql_value::{BinaryAnd, Comparable, EqualityValue, NullableEq, NullableOrd, Tuple, Value};

use std::borrow::{Borrow, Cow};
use std::fmt::{Debug, Formatter};
Expand Down Expand Up @@ -146,6 +146,31 @@ impl<const TARGET: bool, OnMissing: ArgShortCircuit> ArgChecker
}
}

#[derive(Debug)]
pub(crate) struct ComparisonArgChecker<const STRICT: bool, OnMissing: ArgShortCircuit> {
check: PhantomData<DefaultArgChecker<STRICT, OnMissing>>,
}

impl<const STRICT: bool, OnMissing: ArgShortCircuit> ArgChecker
for ComparisonArgChecker<STRICT, OnMissing>
{
#[inline]
fn arg_check<'a>(
typ: &PartiqlShape,
arg: Cow<'a, Value>,
) -> ArgCheckControlFlow<Value, Cow<'a, Value>> {
DefaultArgChecker::<{ STRICT }, OnMissing>::arg_check(typ, arg)
}

fn validate_args(args: &[Cow<'_, Value>]) -> Result<(), Value> {
if args.len() == 2 && args[0].is_comparable_to(&args[1]) {
Ok(())
} else {
Err(OnMissing::propagate())
}
}
}

impl BindEvalExpr for EvalOpBinary {
#[inline]
fn bind<const STRICT: bool>(
Expand All @@ -157,6 +182,7 @@ impl BindEvalExpr for EvalOpBinary {
type InCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateNull<false>>;
type Check<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;
type EqCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<false>>;
type CompCheck<const STRICT: bool> = ComparisonArgChecker<STRICT, PropagateMissing<true>>;
type MathCheck<const STRICT: bool> = DefaultArgChecker<STRICT, PropagateMissing<true>>;

macro_rules! create {
Expand All @@ -177,6 +203,12 @@ impl BindEvalExpr for EvalOpBinary {
};
}

macro_rules! comparison {
($f:expr) => {
create!(CompCheck<STRICT>, [type_dynamic!(), type_dynamic!()], $f)
};
}

macro_rules! math {
($f:expr) => {{
let nums = PartiqlShapeBuilder::init_or_get().any_of(type_numeric!());
Expand All @@ -195,10 +227,10 @@ impl BindEvalExpr for EvalOpBinary {
let wrap = EqualityValue::<false, Value>;
NullableEq::neq(&wrap(lhs), &wrap(rhs))
}),
EvalOpBinary::Gt => equality!(NullableOrd::gt),
EvalOpBinary::Gteq => equality!(NullableOrd::gteq),
EvalOpBinary::Lt => equality!(NullableOrd::lt),
EvalOpBinary::Lteq => equality!(NullableOrd::lteq),
EvalOpBinary::Gt => comparison!(NullableOrd::gt),
EvalOpBinary::Gteq => comparison!(NullableOrd::gteq),
EvalOpBinary::Lt => comparison!(NullableOrd::lt),
EvalOpBinary::Lteq => comparison!(NullableOrd::lteq),
EvalOpBinary::Add => math!(|lhs, rhs| lhs + rhs),
EvalOpBinary::Sub => math!(|lhs, rhs| lhs - rhs),
EvalOpBinary::Mul => math!(|lhs, rhs| lhs * rhs),
Expand Down
115 changes: 115 additions & 0 deletions partiql/tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use partiql_ast_passes::error::AstTransformationError;
use partiql_catalog::catalog::{Catalog, PartiqlCatalog};
use partiql_catalog::context::SystemContext;
use partiql_eval as eval;
use partiql_eval::env::basic::MapBindings;
use partiql_eval::error::{EvalErr, PlanErr};
use partiql_eval::eval::{BasicContext, EvalPlan, EvalResult, Evaluated};
use partiql_eval::plan::EvaluationMode;
use partiql_logical as logical;
use partiql_parser::{Parsed, ParserError, ParserResult};
use partiql_value::{DateTime, Value};
use std::error::Error;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum TestError<'a> {
#[error("Parse error: {0:?}")]
Parse(ParserError<'a>),
#[error("Lower error: {0:?}")]
Lower(AstTransformationError),
#[error("Plan error: {0:?}")]
Plan(PlanErr),
#[error("Evaluation error: {0:?}")]
Eval(EvalErr),
#[error("Other: {0:?}")]
Other(Box<dyn Error>),
}

impl<'a> From<ParserError<'a>> for TestError<'a> {
fn from(err: ParserError<'a>) -> Self {
TestError::Parse(err)
}
}

impl From<AstTransformationError> for TestError<'_> {
fn from(err: AstTransformationError) -> Self {
TestError::Lower(err)
}
}

impl From<PlanErr> for TestError<'_> {
fn from(err: PlanErr) -> Self {
TestError::Plan(err)
}
}

impl From<EvalErr> for TestError<'_> {
fn from(err: EvalErr) -> Self {
TestError::Eval(err)
}
}

impl From<Box<dyn Error>> for TestError<'_> {
fn from(err: Box<dyn Error>) -> Self {
TestError::Other(err)
}
}

#[track_caller]
#[inline]
pub fn parse(statement: &str) -> ParserResult<'_> {
partiql_parser::Parser::default().parse(statement)
}

#[track_caller]
#[inline]
pub fn lower(
catalog: &dyn Catalog,
parsed: &Parsed<'_>,
) -> Result<logical::LogicalPlan<logical::BindingsOp>, AstTransformationError> {
let planner = partiql_logical_planner::LogicalPlanner::new(catalog);
planner.lower(parsed)
}

#[track_caller]
#[inline]
pub fn compile(
mode: EvaluationMode,
catalog: &dyn Catalog,
logical: logical::LogicalPlan<logical::BindingsOp>,
) -> Result<EvalPlan, PlanErr> {
let mut planner = eval::plan::EvaluatorPlanner::new(mode, catalog);
planner.compile(&logical)
}

#[track_caller]
#[inline]
pub fn evaluate(mut plan: EvalPlan, bindings: MapBindings<Value>) -> EvalResult {
let sys = SystemContext {
now: DateTime::from_system_now_utc(),
};
let ctx = BasicContext::new(bindings, sys);
plan.execute_mut(&ctx)
}

#[track_caller]
#[inline]
pub fn eval_query_with_catalog<'a, 'b>(
statement: &'a str,
catalog: &'b dyn Catalog,
mode: EvaluationMode,
) -> Result<Evaluated, TestError<'a>> {
let parsed = parse(statement)?;
let lowered = lower(catalog, &parsed)?;
let bindings = Default::default();
let plan = compile(mode, catalog, lowered)?;
Ok(evaluate(plan, bindings)?)
}

#[track_caller]
#[inline]
pub fn eval_query(statement: &str, mode: EvaluationMode) -> Result<Evaluated, TestError<'_>> {
let catalog = PartiqlCatalog::default();
eval_query_with_catalog(statement, &catalog, mode)
}
54 changes: 54 additions & 0 deletions partiql/tests/comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use crate::common::{
compile, eval_query, eval_query_with_catalog, evaluate, lower, parse, TestError,
};
use assert_matches::assert_matches;
use partiql_catalog::catalog::{Catalog, PartiqlCatalog};
use partiql_catalog::extension::Extension;
use partiql_eval::eval::Evaluated;
use partiql_eval::plan::EvaluationMode;
use partiql_extension_value_functions::PartiqlValueFnExtension;
use partiql_value::Value;
use std::os::macos::raw::stat;

mod common;

#[track_caller]
#[inline]
pub fn eval<'a>(statement: &'a str) {
dbg!(&statement);
let res = eval_query(statement, EvaluationMode::Permissive);
assert_matches!(res, Ok(_));
let res = res.unwrap().result;
assert_matches!(res, Value::Missing);

let res = eval_query(statement, EvaluationMode::Strict);
assert_matches!(res, Err(_));
let err = res.unwrap_err();
assert_matches!(err, TestError::Eval(_));
}

#[track_caller]
#[inline]
pub fn eval_op<'a>(op: &'a str) {
eval(&format!("1 {op} 'foo'"))
}

#[test]
fn lt() {
eval_op("<")
}

#[test]
fn gt() {
eval_op(">")
}

#[test]
fn lte() {
eval_op("<=")
}

#[test]
fn gte() {
eval_op(">=")
}
Loading

0 comments on commit e2d1f49

Please sign in to comment.