Skip to content

Commit

Permalink
fix(interactive): support CaseWhen in GIE runtime (#3868)
Browse files Browse the repository at this point in the history
As titled. Support case when, e.g., 
```
 MATCH (post: POST)
        WITH post,
             CASE
               WHEN post.creationDate < 20120629020000000
         AND post.creationDate >= 20120601000000000 THEN 1
               ELSE 0
             END AS valid,
             CASE
               WHEN 20120601000000000 > post.creationDate THEN 1
               ELSE 0
             END AS inValid
RETURN post
```

Fixes #3736
  • Loading branch information
BingqingLyu authored Jun 7, 2024
1 parent 00735b7 commit 8bd72c3
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,38 @@ public static QueryContext get_ldbc_4_test() {
return new QueryContext(query, expected);
}

// minor diff with get_ldbc_4_test since in experiment store the date is in a different format
// (e.g., 20120629020000000)
public static QueryContext get_ldbc_4_test_exp() {
String query =
"MATCH (person:PERSON {id:"
+ " 10995116278874})-[:KNOWS]-(friend:PERSON)<-[:HASCREATOR]-(post:POST)-[:HASTAG]->(tag:"
+ " TAG)\n"
+ "WITH DISTINCT tag, post\n"
+ "WITH tag,\n"
+ " CASE\n"
+ " WHEN post.creationDate < 20120629020000000 AND post.creationDate >="
+ " 20120601000000000 THEN 1\n"
+ " ELSE 0\n"
+ " END AS valid,\n"
+ " CASE\n"
+ " WHEN 20120601000000000 > post.creationDate THEN 1\n"
+ " ELSE 0\n"
+ " END AS inValid\n"
+ "WITH tag, sum(valid) AS postCount, sum(inValid) AS inValidPostCount\n"
+ "WHERE postCount>0 AND inValidPostCount=0\n"
+ "\n"
+ "RETURN tag.name AS tagName, postCount\n"
+ "ORDER BY postCount DESC, tagName ASC\n"
+ "LIMIT 10;";
List<String> expected =
Arrays.asList(
"Record<{tagName: \"Norodom_Sihanouk\", postCount: 3}>",
"Record<{tagName: \"George_Clooney\", postCount: 1}>",
"Record<{tagName: \"Louis_Philippe_I\", postCount: 1}>");
return new QueryContext(query, expected);
}

public static QueryContext get_ldbc_6_test() {
String query =
"MATCH (person:PERSON"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ public void run_ldbc_4_test() {
Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString());
}

@Test
public void run_ldbc_4_test_exp() {
assumeTrue("pegasus".equals(System.getenv("ENGINE_TYPE")));
QueryContext testQuery = LdbcQueries.get_ldbc_4_test_exp();
Result result = session.run(testQuery.getQuery());
Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString());
}

@Test
public void run_ldbc_6_test() {
QueryContext testQuery = LdbcQueries.get_ldbc_6_test();
Expand Down
123 changes: 119 additions & 4 deletions interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use ir_common::expr_parse::to_suffix_expr;
use ir_common::generated::common as common_pb;
use ir_common::{NameOrId, ALL_KEY, ID_KEY, LABEL_KEY, LENGTH_KEY};

use super::eval_pred::PEvaluator;
use crate::apis::{Element, PropKey};
use crate::utils::expr::eval_pred::EvalPred;
use crate::utils::expr::{ExprEvalError, ExprEvalResult};
Expand Down Expand Up @@ -63,13 +64,51 @@ pub enum Function {
Extract(common_pb::extract::Interval),
}

#[derive(Debug)]
pub struct CaseWhen {
when_then_evals: Vec<(PEvaluator, Evaluator)>,
else_eval: Evaluator,
}

impl TryFrom<common_pb::Case> for CaseWhen {
type Error = ParsePbError;

fn try_from(case: common_pb::Case) -> Result<Self, Self::Error> {
let mut when_then_evals = Vec::with_capacity(case.when_then_expressions.len());
for when_then in &case.when_then_expressions {
let when = when_then
.when_expression
.as_ref()
.ok_or(ParsePbError::EmptyFieldError(format!("missing when expression {:?}", case)))?;
let then = when_then
.then_result_expression
.as_ref()
.ok_or(ParsePbError::EmptyFieldError(format!("missing then expression: {:?}", case)))?;
when_then_evals.push((PEvaluator::try_from(when.clone())?, Evaluator::try_from(then.clone())?));
}
let else_result_expression = case
.else_result_expression
.as_ref()
.ok_or(ParsePbError::EmptyFieldError(format!("missing else expression: {:?}", case)))?;
let else_eval = Evaluator::try_from(else_result_expression.clone())?;
Ok(Self { when_then_evals, else_eval })
}
}

/// A conditional expression for evaluating a casewhen. More conditional expressions can be added in the future, e.g., COALESCE(),NULLIF() etc.
#[derive(Debug)]
pub enum Conditional {
Case(CaseWhen),
}

/// An inner representation of `common_pb::ExprOpr` for one-shot translation of `common_pb::ExprOpr`.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum InnerOpr {
Logical(common_pb::Logical),
Arith(common_pb::Arithmetic),
Function(Function),
Operand(Operand),
Conditional(Conditional),
}

impl ToString for InnerOpr {
Expand All @@ -79,6 +118,7 @@ impl ToString for InnerOpr {
InnerOpr::Arith(arith) => format!("{:?}", arith),
InnerOpr::Operand(item) => format!("{:?}", item),
InnerOpr::Function(func) => format!("{:?}", func),
InnerOpr::Conditional(conditional) => format!("{:?}", conditional),
}
}
}
Expand Down Expand Up @@ -269,6 +309,22 @@ pub(crate) fn apply_logical<'a>(
}
}

pub(crate) fn apply_condition_expr<'a, E: Element, C: Context<E>>(
condition: &Conditional, context: Option<&C>,
) -> ExprEvalResult<Object> {
match condition {
Conditional::Case(case) => {
let else_expr = &case.else_eval;
for (when, then) in case.when_then_evals.iter() {
if when.eval_bool(context)? {
return then.eval(context);
}
}
return else_expr.eval(context);
}
}
}

// Private api
impl Evaluator {
/// Evaluate simple expression that contains less than three operators
Expand All @@ -281,7 +337,11 @@ impl Evaluator {
if self.suffix_tree.is_empty() {
Err(ExprEvalError::EmptyExpression)
} else if self.suffix_tree.len() == 1 {
_first.unwrap().eval(context)
if let InnerOpr::Conditional(case) = _first.unwrap() {
apply_condition_expr(case, context)
} else {
_first.unwrap().eval(context)
}
} else if self.suffix_tree.len() == 2 {
let first = _first.unwrap();
let second = _second.unwrap();
Expand Down Expand Up @@ -591,6 +651,7 @@ impl TryFrom<common_pb::ExprOpr> for InnerOpr {
Extract(extract) => Ok(Self::Function(Function::Extract(unsafe {
std::mem::transmute::<_, common_pb::extract::Interval>(extract.interval)
}))),
Case(case) => Ok(Self::Conditional(Conditional::Case(case.clone().try_into()?))),
_ => Ok(Self::Operand(unit.clone().try_into()?)),
}
} else {
Expand Down Expand Up @@ -974,7 +1035,7 @@ mod tests {
#[test]
fn test_eval_variable() {
// [v0: id = 1, label = 9, age = 31, name = John, birthday = 19900416, hobbies = [football, guitar]]
// [v1: id = 2, label = 11, age = 26, name = Jimmy, birthday = 19950816]
// [v1: id = 2, label = 11, age = 26, name = Nancy, birthday = 19950816]
let ctxt = prepare_context();
let cases: Vec<&str> = vec![
"@0.~id", // 1
Expand Down Expand Up @@ -1102,7 +1163,7 @@ mod tests {
#[test]
fn test_eval_is_null() {
// [v0: id = 1, label = 9, age = 31, name = John, birthday = 19900416, hobbies = [football, guitar]]
// [v1: id = 2, label = 11, age = 26, name = Jimmy, birthday = 19950816]
// [v1: id = 2, label = 11, age = 26, name = Nancy, birthday = 19950816]
let ctxt = prepare_context();
let cases: Vec<&str> = vec![
"isNull @0.hobbies", // false
Expand Down Expand Up @@ -1344,4 +1405,58 @@ mod tests {
assert_eq!(eval.eval::<(), NoneContext>(None).unwrap(), expected);
}
}

fn prepare_casewhen(when_then_exprs: Vec<(&str, &str)>, else_expr: &str) -> common_pb::Expression {
let mut when_then_expressions = vec![];
for (when_expr, then_expr) in when_then_exprs {
when_then_expressions.push(common_pb::case::WhenThen {
when_expression: Some(str_to_expr_pb(when_expr.to_string()).unwrap()),
then_result_expression: Some(str_to_expr_pb(then_expr.to_string()).unwrap()),
});
}
let case_when_opr = common_pb::ExprOpr {
node_type: None,
item: Some(common_pb::expr_opr::Item::Case(common_pb::Case {
when_then_expressions,
else_result_expression: Some(str_to_expr_pb(else_expr.to_string()).unwrap()),
})),
};
common_pb::Expression { operators: vec![case_when_opr] }
}

#[test]
fn test_eval_casewhen() {
// [v0: id = 1, label = 9, age = 31, name = John, birthday = 19900416, hobbies = [football, guitar]]
// [v1: id = 2, label = 11, age = 26, name = Nancy, birthday = 19950816]
let ctxt = prepare_context();
let cases = vec![
(vec![("@0.~id ==1", "1"), ("@0.~id == 2", "2")], "0"),
(vec![("@0.~id > 10", "1"), ("@0.~id<5", "2")], "0"),
(vec![("@0.~id < 10 && @0.~id>20", "true")], "false"),
(vec![("@0.~id < 10 || @0.~id>20", "true")], "false"),
(vec![("@0.~id < 10 && @0.~id>20", "1+2")], "4+5"),
(vec![("@0.~id < 10 || @0.~id>20", "1+2")], "4+5"),
(vec![("true", "@0.name")], "@1.name"),
(vec![("false", "@0.~name")], "@1.name"),
(vec![("isNull @0.hobbies", "true")], "false"),
(vec![("isNull @1.hobbies", "true")], "false"),
];
let expected: Vec<Object> = vec![
object!(1),
object!(2),
object!(false),
object!(true),
object!(9),
object!(3),
object!("John"),
object!("Nancy"),
object!(false),
object!(true),
];

for ((when_then_exprs, else_expr), expected) in cases.into_iter().zip(expected.into_iter()) {
let eval = Evaluator::try_from(prepare_casewhen(when_then_exprs, else_expr)).unwrap();
assert_eq!(eval.eval::<_, Vertices>(Some(&ctxt)).unwrap(), expected);
}
}
}

0 comments on commit 8bd72c3

Please sign in to comment.