Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push down sort through eval #2937

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public static LogicalPlanOptimizer create() {
new MergeFilterAndFilter(),
new PushFilterUnderSort(),
EvalPushDown.PUSH_DOWN_LIMIT,
EvalPushDown.PUSH_DOWN_SORT,
/*
* Phase 2: Transformations that rely on data source push down capability
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,29 @@

import static org.opensearch.sql.planner.optimizer.pattern.Patterns.evalCapture;
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit;
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.sort;
import static org.opensearch.sql.planner.optimizer.rule.EvalPushDown.EvalPushDownBuilder.match;

import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.matching.pattern.CapturePattern;
import com.facebook.presto.matching.pattern.WithPattern;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import lombok.Getter;
import lombok.experimental.Accessors;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalSort;
import org.opensearch.sql.planner.optimizer.Rule;

/**
Expand All @@ -42,6 +51,38 @@ public class EvalPushDown<T extends LogicalPlan> implements Rule<T> {
return logicalEval;
});

/**
 * Rule for a {@link LogicalSort} matched by the {@code sort(evalCapture())} pattern. The matched
 * sort and the captured eval are handed to {@code processSortAndEval}, which decides whether the
 * sort can be moved below the eval.
 *
 * <p>NOTE(review): the pattern presumably matches a sort whose direct child is an eval; confirm
 * against {@code Patterns.sort} and {@code Patterns.evalCapture}.
 */
public static final Rule<LogicalSort> PUSH_DOWN_SORT =
match(sort(evalCapture())).apply(EvalPushDown::processSortAndEval);

/**
 * Tries to rewrite {@code Sort -> Eval -> child} into {@code Eval -> Sort -> child}.
 *
 * <p>The rewrite is deliberately narrow and is applied only when all of the following hold:
 *
 * <ol>
 *   <li>The eval assigns every target field at most once.
 *   <li>Every sort key is a {@link ReferenceExpression}.
 *   <li>A sort key that is assigned by the eval is assigned from another {@link
 *       ReferenceExpression} (a plain rename), not from a computed expression.
 *   <li>The field finally sorted on is not itself assigned by the eval (no internal reference).
 * </ol>
 *
 * <p>If any condition fails, nothing is modified and the original sort is returned.
 *
 * @param sort the sort node that sits on top of the eval
 * @param logicalEval the eval node captured below the sort; on success its child is replaced
 *     in place by the newly created sort
 * @return {@code logicalEval} (now the parent of a new sort) when the rewrite applies, otherwise
 *     the unchanged {@code sort}
 */
private static LogicalPlan processSortAndEval(LogicalSort sort, LogicalEval logicalEval) {
  List<LogicalPlan> child = logicalEval.getChild();
  // The merge function keeps toMap from throwing IllegalStateException when the eval assigns the
  // same field more than once; such evals are detected by the size check below.
  Map<ReferenceExpression, Expression> evalExpressionMap =
      logicalEval.getExpressions().stream()
          .collect(Collectors.toMap(Pair::getKey, Pair::getValue, (first, second) -> second));
  if (evalExpressionMap.size() != logicalEval.getExpressions().size()) {
    // A field is assigned more than once; stay conservative and leave the plan untouched.
    return sort;
  }

  List<Pair<SortOption, Expression>> newSortList = new ArrayList<>();
  for (Pair<SortOption, Expression> sortKey : sort.getSortList()) {
    if (!(sortKey.getRight() instanceof ReferenceExpression)) {
      // Sorting on a computed expression is not supported by this rule.
      return sort;
    }
    ReferenceExpression ref = (ReferenceExpression) sortKey.getRight();
    // Fields not assigned by the eval pass through unchanged.
    Expression replacement = evalExpressionMap.getOrDefault(ref, ref);
    if (!(replacement instanceof ReferenceExpression)) {
      // The sort field is computed by the eval, so it does not exist below the eval.
      return sort;
    }
    ReferenceExpression newRef = (ReferenceExpression) replacement;
    if (evalExpressionMap.containsKey(newRef)) {
      // The replacement is itself assigned by the eval (internal reference).
      return sort;
    }
    newSortList.add(Pair.of(sortKey.getLeft(), newRef));
  }

  // get(0) rather than getFirst() so the code also compiles on JDKs older than 21.
  LogicalSort pushedSort = new LogicalSort(child.get(0), newSortList);
  logicalEval.replaceChildPlans(List.of(pushedSort));
  return logicalEval;
}

private final Capture<LogicalEval> capture;

@Accessors(fluent = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.mockito.Spy;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
Expand Down Expand Up @@ -368,6 +369,66 @@ void push_limit_through_eval_into_scan() {
optimize(limit(eval(relation("schema", table), evalExpr), 10, 5)));
}

/** Sort - Eval --> Eval - Sort: a sort on a field the eval does not assign moves below it. */
@Test
void push_sort_under_eval() {
  Pair<SortOption, Expression> sortOnIntV =
      Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER));
  Pair<ReferenceExpression, Expression> evalPair =
      Pair.of(DSL.ref("name1", INTEGER), DSL.ref("name", STRING));

  LogicalPlan expected = eval(sort(tableScanBuilder, sortOnIntV), evalPair);
  LogicalPlan actual = optimize(sort(eval(relation("schema", table), evalPair), sortOnIntV));

  assertEquals(expected, actual);
}

/** Sort - Eval - Scan --> Eval - Scan: the scan builder accepts the sort once it is below eval. */
@Test
void push_sort_through_eval_into_scan() {
  // The scan builder reports that it can take over the sort.
  when(tableScanBuilder.pushDownSort(any())).thenReturn(true);

  Pair<ReferenceExpression, Expression> evalPair =
      Pair.of(DSL.ref("name1", STRING), DSL.ref("name", STRING));
  Pair<SortOption, Expression> sortOnIntV =
      Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER));

  LogicalPlan expected = eval(tableScanBuilder, evalPair);
  LogicalPlan actual = optimize(sort(eval(relation("schema", table), evalPair), sortOnIntV));

  assertEquals(expected, actual);
}

@Test
void no_push_sort_under_eval_if_sort_field_is_not_reference_expression() {
  // A sort key that is a computed expression (intV + 1) rather than a plain field reference
  // must leave the plan unchanged.
  Pair<ReferenceExpression, Expression> evalPair =
      Pair.of(DSL.ref("name1", INTEGER), DSL.ref("name", STRING));
  Pair<SortOption, Expression> sortOnComputedValue =
      Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.add(DSL.ref("intV", INTEGER), DSL.literal(1)));

  LogicalPlan unoptimized = sort(eval(relation("schema", table), evalPair), sortOnComputedValue);

  assertEquals(unoptimized, optimize(unoptimized));
}

@Test
void no_push_sort_under_eval_if_replaced_field_is_not_reference_expression() {
  // The sort field is assigned by the eval from a computed expression (intV + 1), so the sort
  // must stay above the eval.
  ReferenceExpression intV = DSL.ref("intV", INTEGER);
  Pair<SortOption, Expression> sortOnIntV = Pair.of(Sort.SortOption.DEFAULT_ASC, intV);
  Pair<ReferenceExpression, Expression> evalComputedIntV =
      Pair.of(intV, DSL.add(DSL.ref("intV", INTEGER), DSL.literal(1)));

  LogicalPlan unoptimized = sort(eval(relation("schema", table), evalComputedIntV), sortOnIntV);

  assertEquals(unoptimized, optimize(unoptimized));
}

@Test
void no_push_sort_under_eval_if_having_internal_reference() {
  // The sort field (intV) is assigned from name1, which is itself assigned by the same eval;
  // such an internal reference must keep the sort above the eval.
  ReferenceExpression intV = DSL.ref("intV", INTEGER);
  ReferenceExpression name1 = DSL.ref("name1", INTEGER);
  Pair<SortOption, Expression> sortOnIntV = Pair.of(Sort.SortOption.DEFAULT_ASC, intV);
  Pair<ReferenceExpression, Expression> assignName1 = Pair.of(name1, DSL.ref("name", STRING));
  Pair<ReferenceExpression, Expression> assignIntVFromName1 = Pair.of(intV, name1);

  LogicalPlan unoptimized =
      sort(eval(relation("schema", table), assignName1, assignIntVFromName1), sortOnIntV);

  assertEquals(unoptimized, optimize(unoptimized));
}

private LogicalPlan optimize(LogicalPlan plan) {
final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create();
return optimizer.optimize(plan);
Expand Down
13 changes: 13 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ public void testLimitPushDownExplain() throws Exception {
+ "| fields ageMinus"));
}

@Test
public void testSortPushDownThroughEvalExplain() throws Exception {
  // Explains a query that sorts on an eval alias of a plain field and compares the result with
  // the stored expected plan.
  String expected = loadFromFile("expectedOutput/ppl/explain_sort_push_through_eval.json");
  String actual =
      explainQueryToString(
          "source=opensearch-sql_test_index_account"
              + "| eval newAge = age"
              + "| sort newAge"
              + "| fields newAge");

  assertJsonEquals(expected, actual);
}

String loadFromFile(String filename) throws Exception {
URI uri = Resources.getResource(filename).toURI();
return new String(Files.readAllBytes(Paths.get(uri)));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"root": {
"name": "ProjectOperator",
"description": {
"fields": "[newAge]"
},
"children": [
{
"name": "EvalOperator",
"description": {
"expressions": {
"newAge": "age"
}
},
"children": [
{
"name": "OpenSearchIndexScan",
"description": {
"request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, searchDone=false)"
},
"children": []
}
]
}
]
}
}
Loading