diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index e805b0dea5..39afb6d41e 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -48,6 +48,7 @@ public static LogicalPlanOptimizer create() { new MergeFilterAndFilter(), new PushFilterUnderSort(), EvalPushDown.PUSH_DOWN_LIMIT, + EvalPushDown.PUSH_DOWN_SORT, /* * Phase 2: Transformations that rely on data source push down capability */ diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java index 17eaed0e8c..7df5be80a7 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EvalPushDown.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.planner.optimizer.pattern.Patterns.evalCapture; import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.sort; import static org.opensearch.sql.planner.optimizer.rule.EvalPushDown.EvalPushDownBuilder.match; import com.facebook.presto.matching.Capture; @@ -14,13 +15,21 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.pattern.CapturePattern; import com.facebook.presto.matching.pattern.WithPattern; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.BiFunction; +import java.util.stream.Collectors; import lombok.Getter; import lombok.experimental.Accessors; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalSort; import org.opensearch.sql.planner.optimizer.Rule; /** @@ -42,6 +51,38 @@ public class EvalPushDown implements Rule { return logicalEval; }); + public static final Rule PUSH_DOWN_SORT = + match(sort(evalCapture())).apply(EvalPushDown::processSortAndEval); + + private static LogicalPlan processSortAndEval(LogicalSort sort, LogicalEval logicalEval) { + List child = logicalEval.getChild(); + Map evalExpressionMap = + logicalEval.getExpressions().stream() + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + List> sortList = sort.getSortList(); + List> newSortList = new ArrayList<>(); + for (Pair pair : sortList) { + /* + Narrow down the optimization to only support: + 1. The expression in sort and replaced expression are both ReferenceExpression. + 2. No internal reference in eval. + */ + if (pair.getRight() instanceof ReferenceExpression) { + ReferenceExpression ref = (ReferenceExpression) pair.getRight(); + Expression newExpr = evalExpressionMap.getOrDefault(ref, ref); + if (newExpr instanceof ReferenceExpression) { + ReferenceExpression newRef = (ReferenceExpression) newExpr; + if (!evalExpressionMap.containsKey(newRef)) { + newSortList.add(Pair.of(pair.getLeft(), newRef)); + } else return sort; + } else return sort; + } else return sort; + } + sort = new LogicalSort(child.getFirst(), newSortList); + logicalEval.replaceChildPlans(List.of(sort)); + return logicalEval; + } + private final Capture capture; @Accessors(fluent = true) diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 20996503b4..b7a3fbe023 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -42,6 +42,7 @@ import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -368,6 +369,66 @@ void push_limit_through_eval_into_scan() { optimize(limit(eval(relation("schema", table), evalExpr), 10, 5))); } + /** Sort - Eval --> Eval - Sort. */ + @Test + void push_sort_under_eval() { + ReferenceExpression sortRef = DSL.ref("intV", INTEGER); + ReferenceExpression evalRef = DSL.ref("name1", INTEGER); + Pair evalExpr = Pair.of(evalRef, DSL.ref("name", STRING)); + Pair sortExpr = Pair.of(Sort.SortOption.DEFAULT_ASC, sortRef); + assertEquals( + eval(sort(tableScanBuilder, sortExpr), evalExpr), + optimize(sort(eval(relation("schema", table), evalExpr), sortExpr))); + } + + /** Sort - Eval - Scan --> Eval - Scan. */ + @Test + void push_sort_through_eval_into_scan() { + when(tableScanBuilder.pushDownSort(any())).thenReturn(true); + Pair evalExpr = + Pair.of(DSL.ref("name1", STRING), DSL.ref("name", STRING)); + Pair sortExpr = + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)); + assertEquals( + eval(tableScanBuilder, evalExpr), + optimize(sort(eval(relation("schema", table), evalExpr), sortExpr))); + } + + @Test + void no_push_sort_under_eval_if_sort_field_is_not_reference_expression() { + // don't push sort if sort field is not ReferenceExpression + ReferenceExpression evalRef = DSL.ref("name1", INTEGER); + Pair evalExpr = Pair.of(evalRef, DSL.ref("name", STRING)); + Expression nonRefExpr = DSL.add(DSL.ref("intV", INTEGER), DSL.literal(1)); + Pair sortExprWithNonRef = + Pair.of(Sort.SortOption.DEFAULT_ASC, nonRefExpr); + LogicalPlan originPlan = sort(eval(relation("schema", table), evalExpr), sortExprWithNonRef); + assertEquals(originPlan, optimize(originPlan)); + } + + @Test + void no_push_sort_under_eval_if_replaced_field_is_not_reference_expression() { + // don't push sort if replaced expr in eval is not ReferenceExpression + ReferenceExpression sortRef = DSL.ref("intV", INTEGER); + Pair sortExpr = Pair.of(Sort.SortOption.DEFAULT_ASC, sortRef); + Expression nonRefExpr = DSL.add(DSL.ref("intV", INTEGER), DSL.literal(1)); + Pair evalExprWithNonRef = Pair.of(sortRef, nonRefExpr); + LogicalPlan originPlan = sort(eval(relation("schema", table), evalExprWithNonRef), sortExpr); + assertEquals(originPlan, optimize(originPlan)); + } + + @Test + void no_push_sort_under_eval_if_having_internal_reference() { + // don't push sort if there are internal reference in eval + ReferenceExpression sortRef = DSL.ref("intV", INTEGER); + ReferenceExpression evalRef = DSL.ref("name1", INTEGER); + Pair sortExpr = Pair.of(Sort.SortOption.DEFAULT_ASC, sortRef); + Pair evalExpr = Pair.of(evalRef, DSL.ref("name", STRING)); + Pair evalExpr2 = Pair.of(sortRef, evalRef); + LogicalPlan originPlan = sort(eval(relation("schema", table), evalExpr, evalExpr2), sortExpr); + assertEquals(originPlan, optimize(originPlan)); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); return optimizer.optimize(plan); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index c6b21e1605..9cf699833d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -89,6 +89,19 @@ public void testLimitPushDownExplain() throws Exception { + "| fields ageMinus")); } + @Test + public void testSortPushDownThroughEvalExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_sort_push_through_eval.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| eval newAge = age" + + "| sort newAge" + + "| fields newAge")); + } + String loadFromFile(String filename) throws Exception { URI uri = Resources.getResource(filename).toURI(); return new String(Files.readAllBytes(Paths.get(uri))); diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push_through_eval.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push_through_eval.json new file mode 100644 index 0000000000..d3c46245f9 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_sort_push_through_eval.json @@ -0,0 +1,27 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[newAge]" + }, + "children": [ + { + "name": "EvalOperator", + "description": { + "expressions": { + "newAge": "age" + } + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)" + }, + "children": [] + } + ] + } + ] + } +}