Skip to content

Commit

Permalink
Adding rule to merge simplify plan generated for union of constant va…
Browse files Browse the repository at this point in the history
…lues
  • Loading branch information
ankitdixit committed Sep 30, 2020
1 parent a9d5c4c commit b8b5441
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ public final class SystemSessionProperties
public static final String DYNAMIC_FILTERING_MAX_PER_DRIVER_SIZE = "dynamic_filtering_max_per_driver_size";
public static final String LEGACY_TYPE_COERCION_WARNING_ENABLED = "legacy_type_coercion_warning_enabled";
public static final String INLINE_SQL_FUNCTIONS = "inline_sql_functions";
public static final String OPTIMIZE_UNION_OVER_VALUES = "optimize_union_over_values";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -875,6 +876,11 @@ public SystemSessionProperties(
INLINE_SQL_FUNCTIONS,
"Inline SQL function definition at plan time",
featuresConfig.isInlineSqlFunctions(),
false),
booleanProperty(
OPTIMIZE_UNION_OVER_VALUES,
"Merge values nodes under union operator",
featuresConfig.isOptimizeUnionOverValues(),
false));
}

Expand Down Expand Up @@ -1481,4 +1487,9 @@ public static boolean isInlineSqlFunctions(Session session)
{
return session.getSystemProperty(INLINE_SQL_FUNCTIONS, Boolean.class);
}

/**
 * Reads the {@code optimize_union_over_values} system session property.
 *
 * @param session the session whose system properties are consulted
 * @return the boolean value of {@link #OPTIMIZE_UNION_OVER_VALUES} for the given session
 */
public static boolean isOptimizeUnionOverValues(Session session)
{
    Boolean optimizeUnionOverValues = session.getSystemProperty(OPTIMIZE_UNION_OVER_VALUES, Boolean.class);
    return optimizeUnionOverValues;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ public class FeaturesConfig
private boolean optimizeNullsInJoin;
private boolean pushdownDereferenceEnabled;
private boolean inlineSqlFunctions = true;
private boolean optimizeUnionOverValues = true;

private String warnOnNoTableLayoutFilter = "";

Expand Down Expand Up @@ -1398,4 +1399,16 @@ public FeaturesConfig setInlineSqlFunctions(boolean inlineSqlFunctions)
this.inlineSqlFunctions = inlineSqlFunctions;
return this;
}

/**
 * Returns the configured value of the {@code optimize-union-over-values} flag.
 *
 * @return the current value of the {@code optimizeUnionOverValues} field
 */
public boolean isOptimizeUnionOverValues()
{
    return this.optimizeUnionOverValues;
}

/**
 * Sets the {@code optimize-union-over-values} configuration flag.
 *
 * @param optimizeUnionOverValues the new value of the flag
 * @return this {@code FeaturesConfig}, so that setter calls can be chained
 */
@Config("optimize-union-over-values")
public FeaturesConfig setOptimizeUnionOverValues(boolean optimizeUnionOverValues)
{
this.optimizeUnionOverValues = optimizeUnionOverValues;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations;
import com.facebook.presto.sql.planner.iterative.rule.InlineProjections;
import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions;
import com.facebook.presto.sql.planner.iterative.rule.MergeConstantValuesUnderUnion;
import com.facebook.presto.sql.planner.iterative.rule.MergeFilters;
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct;
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort;
Expand Down Expand Up @@ -72,6 +73,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns;
import com.facebook.presto.sql.planner.iterative.rule.PruneWindowColumns;
import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import com.facebook.presto.sql.planner.iterative.rule.PushConstantProjectIntoEmptyValuesNode;
import com.facebook.presto.sql.planner.iterative.rule.PushDownDereferences;
import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct;
import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughOuterJoin;
Expand Down Expand Up @@ -250,6 +252,14 @@ public PlanOptimizers(
new PushProjectionThroughUnion(),
new PushProjectionThroughExchange()));

IterativeOptimizer mergeValuesUnderUnion = new IterativeOptimizer(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(
new PushConstantProjectIntoEmptyValuesNode(),
new MergeConstantValuesUnderUnion()));

IterativeOptimizer simplifyOptimizer = new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down Expand Up @@ -404,6 +414,7 @@ public PlanOptimizers(
inlineProjections,
simplifyRowExpressionOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
projectionPushDown,
mergeValuesUnderUnion,
new UnaliasSymbolReferences(metadata.getFunctionManager()), // Run again because predicate pushdown and projection pushdown might add more projections
new PruneUnreferencedOutputs(), // Make sure to run this before index join. Filtered projections may not have all the columns.
new IndexJoinOptimizer(metadata), // Run this after projections and filters have been fully simplified and pushed down
Expand Down Expand Up @@ -512,7 +523,7 @@ public PlanOptimizers(
.addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager).rules())
.add(new InlineProjections(metadata.getFunctionManager()))
.build()));

builder.add(mergeValuesUnderUnion);
if (!forceSingleNode) {
builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges
builder.add((new IterativeOptimizer(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.isOptimizeUnionOverValues;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.planner.plan.Patterns.sources;
import static com.facebook.presto.sql.planner.plan.Patterns.union;

/**
 * Rule that replaces a {@link UnionNode} whose sources all resolve to {@link ValuesNode}s
 * with a single {@link ValuesNode} containing the rows of every source.
 *
 * <p>The merged node reuses the union's id and output variables. Rows are emitted in the
 * order implied by the union's variable mapping: all rows of the first mapped input, then
 * all rows of the second, and so on. Sources that contribute zero rows are supported and
 * simply add nothing to the result.
 *
 * <p>The rule is only active when the {@code optimize_union_over_values} session property
 * is enabled.
 */
public class MergeConstantValuesUnderUnion
        implements Rule<UnionNode>
{
    private static final Capture<List<PlanNode>> CHILDREN = newCapture();
    private static final Pattern<UnionNode> PATTERN = union()
            .with(sources().capturedAs(CHILDREN));

    @Override
    public Pattern<UnionNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isOptimizeUnionOverValues(session);
    }

    /**
     * Merges the union's sources into one {@link ValuesNode} when every source resolves to a
     * {@link ValuesNode}; otherwise leaves the plan untouched.
     *
     * @param node the matched union
     * @param captures pattern captures holding the union's sources
     * @param context rule context; its lookup is used to resolve each source
     * @return the merged {@link ValuesNode}, or {@code Result.empty()} if any source is not a {@link ValuesNode}
     */
    @Override
    public Result apply(UnionNode node, Captures captures, Context context)
    {
        // Resolve each captured source through the lookup exactly once and bail out on the
        // first one that is not a ValuesNode.
        ImmutableList.Builder<ValuesNode> resolvedSources = ImmutableList.builder();
        for (PlanNode source : captures.get(CHILDREN)) {
            PlanNode resolved = context.getLookup().resolve(source);
            if (!(resolved instanceof ValuesNode)) {
                return Result.empty();
            }
            resolvedSources.add((ValuesNode) resolved);
        }
        List<ValuesNode> values = resolvedSources.build();

        ImmutableMap<VariableReferenceExpression, List<RowExpression>> inputVariablesToValuesMap = getInputVariablesToValuesMap(values);

        List<List<RowExpression>> finalValues = generateRowsFromValues(node, values, inputVariablesToValuesMap);
        return Result.ofPlanNode(
                new ValuesNode(
                        node.getId(),
                        node.getOutputVariables(),
                        finalValues));
    }

    /**
     * Transposes the per-output-variable columns back into rows, one row per input row,
     * with cells ordered like {@code node.getOutputVariables()}.
     */
    private static List<List<RowExpression>> generateRowsFromValues(UnionNode node, List<ValuesNode> values, Map<VariableReferenceExpression, List<RowExpression>> columnToValuesMap)
    {
        Map<VariableReferenceExpression, List<RowExpression>> outputVarToValuesMap = getOutputVariablesToValuesMap(node, columnToValuesMap);
        List<Iterator<RowExpression>> iterators = node.getOutputVariables().stream()
                .map(variable -> outputVarToValuesMap.get(variable).iterator())
                .collect(Collectors.toList());

        int totalRows = values.stream()
                .mapToInt(value -> value.getRows().size())
                .sum();
        ImmutableList.Builder<List<RowExpression>> finalValues = ImmutableList.builder();

        // Advance every column iterator in lock step to assemble one output row at a time.
        for (int i = 0; i < totalRows; i++) {
            ImmutableList.Builder<RowExpression> row = ImmutableList.builder();
            for (Iterator<RowExpression> iterator : iterators) {
                row.add(iterator.next());
            }
            finalValues.add(row.build());
        }
        return finalValues.build();
    }

    /**
     * For each union output variable, concatenates the columns of the input variables mapped
     * to it, in mapping order.
     *
     * @throws IllegalStateException if a mapped input variable is not produced by any source
     */
    private static Map<VariableReferenceExpression, List<RowExpression>> getOutputVariablesToValuesMap(UnionNode node, Map<VariableReferenceExpression, List<RowExpression>> columnToValuesMap)
    {
        Map<VariableReferenceExpression, List<RowExpression>> outputVarToValuesMap = new HashMap<>();
        for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> entry : node.getVariableMapping().entrySet()) {
            List<RowExpression> outputColumn = outputVarToValuesMap.computeIfAbsent(entry.getKey(), key -> new ArrayList<>());
            for (VariableReferenceExpression input : entry.getValue()) {
                List<RowExpression> inputColumn = columnToValuesMap.get(input);
                if (inputColumn == null) {
                    throw new IllegalStateException("Union input variable " + input + " is not produced by any source ValuesNode");
                }
                outputColumn.addAll(inputColumn);
            }
        }
        return outputVarToValuesMap;
    }

    /**
     * Collects, for every output variable of every source, the expressions found in that
     * column across the source's rows.
     */
    private static ImmutableMap<VariableReferenceExpression, List<RowExpression>> getInputVariablesToValuesMap(List<ValuesNode> values)
    {
        Map<VariableReferenceExpression, List<RowExpression>> columnToValuesMap = new HashMap<>();
        for (ValuesNode value : values) {
            List<VariableReferenceExpression> columns = value.getOutputVariables();
            // Register every column up front: a source with zero rows must still map its
            // variables to an empty column, otherwise the later lookup yields null and
            // addAll(null) fails with a NullPointerException.
            for (VariableReferenceExpression column : columns) {
                columnToValuesMap.computeIfAbsent(column, key -> new ArrayList<>());
            }
            for (List<RowExpression> row : value.getRows()) {
                for (int i = 0; i < row.size(); i++) {
                    columnToValuesMap.get(columns.get(i)).add(row.get(i));
                }
            }
        }
        return ImmutableMap.copyOf(columnToValuesMap);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;

import static com.facebook.presto.SystemSessionProperties.isOptimizeUnionOverValues;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.planner.plan.Patterns.values;

/**
 * Rule that folds a {@link ProjectNode} made up solely of {@link ConstantExpression}
 * assignments into its {@link ValuesNode} source, when that source consists of exactly one
 * row with no columns.
 *
 * <p>The replacement is a new {@link ValuesNode} (with a freshly allocated id) that exposes
 * the projection's output variables and holds a single row of the projected constants.
 *
 * <p>The rule is only active when the {@code optimize_union_over_values} session property
 * is enabled.
 */
public class PushConstantProjectIntoEmptyValuesNode
        implements Rule<ProjectNode>
{
    private static final Capture<ValuesNode> CHILD = newCapture();
    private static final Pattern<ProjectNode> PATTERN = project()
            .matching(projectNode -> projectNode.getAssignments().getExpressions().stream()
                    .allMatch(expression -> expression instanceof ConstantExpression))
            .with(source().matching(values().capturedAs(CHILD)));

    @Override
    public Pattern<ProjectNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isOptimizeUnionOverValues(session);
    }

    /**
     * Rewrites the projection into a single-row {@link ValuesNode} of its constants.
     *
     * @param node the matched all-constant projection
     * @param captures pattern captures holding the {@link ValuesNode} source
     * @param context rule context; supplies the id for the replacement node
     * @return the replacement node, or {@code Result.empty()} unless the source has exactly
     *         one row and that row is empty
     */
    @Override
    public Result apply(ProjectNode node, Captures captures, Context context)
    {
        ValuesNode source = captures.get(CHILD);

        // Only a source with exactly one row that carries no columns qualifies.
        boolean singleEmptyRow = source.getRows().size() == 1 && source.getRows().get(0).isEmpty();
        if (!singleEmptyRow) {
            return Result.empty();
        }

        // Lay the constants out in the same order as the projection's output variables.
        List<RowExpression> constants = new ArrayList<>(node.getOutputVariables().size());
        node.getOutputVariables().forEach(variable -> constants.add(node.getAssignments().get(variable)));

        return Result.ofPlanNode(
                new ValuesNode(
                        context.getIdAllocator().getNextId(),
                        node.getOutputVariables(),
                        ImmutableList.of(constants)));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ public void testDefaults()
.setPreferDistributedUnion(true)
.setOptimizeNullsInJoin(false)
.setWarnOnNoTableLayoutFilter("")
.setInlineSqlFunctions(true));
.setInlineSqlFunctions(true)
.setOptimizeUnionOverValues(true));
}

@Test
Expand Down Expand Up @@ -245,6 +246,7 @@ public void testExplicitPropertyMappings()
.put("optimize-nulls-in-join", "true")
.put("warn-on-no-table-layout-filter", "ry@nlikestheyankees,ds")
.put("inline-sql-functions", "false")
.put("optimize-union-over-values", "false")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -341,7 +343,8 @@ public void testExplicitPropertyMappings()
.setPreferDistributedUnion(false)
.setOptimizeNullsInJoin(true)
.setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds")
.setInlineSqlFunctions(false);
.setInlineSqlFunctions(false)
.setOptimizeUnionOverValues(false);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
import java.util.Optional;

import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window;
Expand All @@ -51,10 +49,8 @@ public void testJoin()
"CROSS JOIN (VALUES 1)",
anyTree(
join(INNER, ImmutableList.of(), Optional.empty(),
project(
ImmutableMap.of("X", expression("BIGINT '1'")),
values(ImmutableMap.of())),
values(ImmutableMap.of()))));
values("expr"),
values("field"))));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ public void testScalarSubqueryJoinFilterPushdown()
filter("orderkey = BIGINT '1'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey"))),
anyTree(
project(ImmutableMap.of("orderkey", expression("1")), any())))));
values("expr")))));
}

@Test
Expand Down Expand Up @@ -778,8 +778,7 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
join(LEFT, ImmutableList.of(), Optional.of("BIGINT '3' = ORDERKEY"),
any(
tableScan("orders", ImmutableMap.of("ORDERKEY", "orderkey"))),
project(ImmutableMap.of("NON_NULL", expression("true")),
node(ValuesNode.class)))))));
values("non_null"))))));
}

@Test
Expand Down
Loading

0 comments on commit b8b5441

Please sign in to comment.