Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize plans involving table functions #16012

Merged
merged 9 commits into from
Mar 30, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@
import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns;
import io.trino.sql.planner.iterative.rule.PruneSpatialJoinColumns;
import io.trino.sql.planner.iterative.rule.PruneTableExecuteSourceColumns;
import io.trino.sql.planner.iterative.rule.PruneTableFunctionProcessorColumns;
import io.trino.sql.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns;
import io.trino.sql.planner.iterative.rule.PruneTableScanColumns;
import io.trino.sql.planner.iterative.rule.PruneTableWriterSourceColumns;
import io.trino.sql.planner.iterative.rule.PruneTopNColumns;
Expand Down Expand Up @@ -198,6 +200,7 @@
import io.trino.sql.planner.iterative.rule.RemoveRedundantPredicateAboveTableScan;
import io.trino.sql.planner.iterative.rule.RemoveRedundantSort;
import io.trino.sql.planner.iterative.rule.RemoveRedundantSortBelowLimitWithTies;
import io.trino.sql.planner.iterative.rule.RemoveRedundantTableFunction;
import io.trino.sql.planner.iterative.rule.RemoveRedundantTopN;
import io.trino.sql.planner.iterative.rule.RemoveRedundantWindow;
import io.trino.sql.planner.iterative.rule.RemoveTrivialFilters;
Expand Down Expand Up @@ -438,6 +441,7 @@ public PlanOptimizers(
new RemoveRedundantOffset(),
new RemoveRedundantSort(),
new RemoveRedundantSortBelowLimitWithTies(),
new RemoveRedundantTableFunction(),
new RemoveRedundantTopN(),
new RemoveRedundantDistinctLimit(),
new ReplaceRedundantJoinWithSource(),
Expand Down Expand Up @@ -1034,6 +1038,8 @@ public static Set<Rule<?>> columnPruningRules(Metadata metadata)
new PruneSpatialJoinChildrenColumns(),
new PruneSpatialJoinColumns(),
new PruneTableExecuteSourceColumns(),
new PruneTableFunctionProcessorColumns(),
new PruneTableFunctionProcessorSourceColumns(),
new PruneTableScanColumns(metadata),
new PruneTableWriterSourceColumns(),
new PruneTopNColumns(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor;

/**
 * Drops the pass-through outputs of a TableFunctionProcessorNode that the parent
 * projection does not reference.
 * <p>
 * TableFunctionProcessorNode has two kinds of outputs:
 * - proper outputs, which are the columns produced by the table function,
 * - pass-through outputs, which are the columns copied from table arguments.
 * Only the pass-through symbols are filtered by this rule.
 * Unreferenced proper symbols are kept, because there is currently no way
 * to communicate to the table function the request for not producing certain columns.
 * // TODO prune table function's proper outputs
 */
public class PruneTableFunctionProcessorColumns
        extends ProjectOffPushDownRule<TableFunctionProcessorNode>
{
    public PruneTableFunctionProcessorColumns()
    {
        super(tableFunctionProcessor());
    }

    /**
     * Rebuilds the node so that each pass-through specification keeps only the columns
     * whose symbols appear in {@code referencedOutputs}.
     *
     * @return the rewritten node, or empty when no pass-through column would be removed
     */
    @Override
    protected Optional<PlanNode> pushDownProjectOff(Context context, TableFunctionProcessorNode node, Set<Symbol> referencedOutputs)
    {
        List<PassThroughSpecification> retainedSpecifications = node.getPassThroughSpecifications().stream()
                .map(specification -> keepReferencedColumns(specification, referencedOutputs))
                .collect(toImmutableList());

        // Filtering can only shrink the column lists, so equal totals mean nothing was pruned.
        if (countColumns(node.getPassThroughSpecifications()) == countColumns(retainedSpecifications)) {
            return Optional.empty();
        }

        // Everything except the pass-through specifications is carried over unchanged.
        return Optional.of(new TableFunctionProcessorNode(
                node.getId(),
                node.getName(),
                node.getFunctionCatalog(),
                node.getProperOutputs(),
                node.getSource(),
                node.isPruneWhenEmpty(),
                retainedSpecifications,
                node.getRequiredSymbols(),
                node.getMarkerSymbols(),
                node.getSpecification(),
                node.getPrePartitioned(),
                node.getPreSorted(),
                node.getHashSymbol(),
                node.getHandle()));
    }

    // Copies the specification, keeping only the columns whose symbol is referenced.
    private static PassThroughSpecification keepReferencedColumns(PassThroughSpecification specification, Set<Symbol> referencedOutputs)
    {
        List<PassThroughColumn> referencedColumns = specification.columns().stream()
                .filter(column -> referencedOutputs.contains(column.symbol()))
                .collect(toImmutableList());
        return new PassThroughSpecification(specification.declaredAsPassThrough(), referencedColumns);
    }

    // Total number of pass-through columns across all the given specifications.
    private static int countColumns(List<PassThroughSpecification> specifications)
    {
        int total = 0;
        for (PassThroughSpecification specification : specifications) {
            total += specification.columns().size();
        }
        return total;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.Maps.filterKeys;
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor;

/**
 * This rule prunes the source outputs of a TableFunctionProcessorNode
 * which the node does not need.
 * First, it collects all symbols required for:
 * - pass-through
 * - table function computation
 * - partitioning and ordering (including the hashSymbol)
 * Next, the mapping of input symbols to marker symbols is narrowed down
 * so that it only contains entries for the required symbols.
 * Last, the marker symbols remaining in the mapping are added
 * to the collection of required symbols.
 * Any source output symbol not included in the required symbols
 * can be pruned.
 */
public class PruneTableFunctionProcessorSourceColumns
        implements Rule<TableFunctionProcessorNode>
{
    private static final Pattern<TableFunctionProcessorNode> PATTERN = tableFunctionProcessor();

    @Override
    public Pattern<TableFunctionProcessorNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Restricts the source of the node to the required symbols.
     *
     * @return the node rebuilt over the restricted source and the narrowed marker mapping,
     * or an empty result when the node has no source or {@code restrictOutputs} yields nothing
     */
    @Override
    public Result apply(TableFunctionProcessorNode node, Captures captures, Context context)
    {
        if (node.getSource().isEmpty()) {
            return Result.empty();
        }

        ImmutableSet.Builder<Symbol> required = ImmutableSet.builder();
        addDirectlyRequiredSymbols(node, required);

        // Snapshot taken before any marker symbol is added: the marker mapping is filtered
        // against the directly required symbols only.
        ImmutableSet<Symbol> directlyRequired = required.build();
        Optional<Map<Symbol, Symbol>> retainedMarkers = node.getMarkerSymbols()
                .map(markers -> filterKeys(markers, directlyRequired::contains));

        // The markers of the retained symbols are themselves read from the source.
        retainedMarkers.ifPresent(markers -> required.addAll(markers.values()));

        return restrictOutputs(context.getIdAllocator(), node.getSource().orElseThrow(), required.build())
                .map(restrictedSource -> Result.ofPlanNode(new TableFunctionProcessorNode(
                        node.getId(),
                        node.getName(),
                        node.getFunctionCatalog(),
                        node.getProperOutputs(),
                        Optional.of(restrictedSource),
                        node.isPruneWhenEmpty(),
                        node.getPassThroughSpecifications(),
                        node.getRequiredSymbols(),
                        retainedMarkers,
                        node.getSpecification(),
                        node.getPrePartitioned(),
                        node.getPreSorted(),
                        node.getHashSymbol(),
                        node.getHandle())))
                .orElse(Result.empty());
    }

    // Adds the pass-through symbols, the function's required symbols, the partitioning and
    // ordering symbols, and the hash symbol (when present) to the builder.
    private static void addDirectlyRequiredSymbols(TableFunctionProcessorNode node, ImmutableSet.Builder<Symbol> required)
    {
        node.getPassThroughSpecifications().stream()
                .map(PassThroughSpecification::columns)
                .flatMap(Collection::stream)
                .map(PassThroughColumn::symbol)
                .forEach(required::add);

        node.getRequiredSymbols().forEach(required::addAll);

        node.getSpecification().ifPresent(specification -> {
            required.addAll(specification.getPartitionBy());
            specification.getOrderingScheme()
                    .ifPresent(ordering -> required.addAll(ordering.getOrderBy()));
        });

        node.getHashSymbol().ifPresent(required::add);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.ValuesNode;

import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isEmpty;
import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor;

/**
 * Replaces a TableFunctionProcessorNode with empty values when the "prune when empty"
 * property allows it.
 * <p>
 * A table function can take multiple table arguments. Each argument is either "prune when empty"
 * or "keep when empty".
 * "Prune when empty" means that if this argument has no rows, the function result is empty, so the
 * function can be removed from the plan, and replaced with empty values.
 * "Keep when empty" means that even if the argument has no rows, the function should still be
 * executed, and it can return a non-empty result.
 * All the table arguments are combined into a single source of a TableFunctionProcessorNode.
 * If any argument is "prune when empty", the overall result is "prune when empty".
 */
public class RemoveRedundantTableFunction
        implements Rule<TableFunctionProcessorNode>
{
    private static final Pattern<TableFunctionProcessorNode> PATTERN = tableFunctionProcessor();

    @Override
    public Pattern<TableFunctionProcessorNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * @return an empty ValuesNode carrying the node's id and output symbols when the node is
     * "prune when empty", has a source, and that source is found to be empty;
     * an empty result otherwise
     */
    @Override
    public Result apply(TableFunctionProcessorNode node, Captures captures, Context context)
    {
        // Only a "prune when empty" node that actually has a source is a candidate.
        if (!node.isPruneWhenEmpty() || node.getSource().isEmpty()) {
            return Result.empty();
        }

        if (!isEmpty(node.getSource().orElseThrow(), context.getLookup())) {
            return Result.empty();
        }

        return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
    }
}
Loading