Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plan table function invocation with table arguments #14175

Merged
merged 6 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.DeleteNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
Expand Down Expand Up @@ -328,7 +329,7 @@ public RelationPlan planExpand(Query query)
WindowNode windowNode = new WindowNode(
idAllocator.getNextId(),
checkConvergenceStep.getNode(),
new WindowNode.Specification(ImmutableList.of(), Optional.empty()),
new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()),
ImmutableMap.of(countSymbol, countFunction),
Optional.empty(),
ImmutableSet.of(),
Expand Down Expand Up @@ -1413,7 +1414,7 @@ private <T extends Expression> List<T> scopeAwareDistinct(PlanBuilder subPlan, L
.collect(toImmutableList());
}

private static OrderingScheme translateOrderingScheme(List<SortItem> items, Function<Expression, Symbol> coercions)
public static OrderingScheme translateOrderingScheme(List<SortItem> items, Function<Expression, Symbol> coercions)
{
List<Symbol> coerced = items.stream()
.map(SortItem::getSortKey)
Expand Down Expand Up @@ -1829,7 +1830,7 @@ private PlanBuilder planWindow(
}
}

WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get);
DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get);

// Rewrite frame bounds in terms of pre-projected inputs
WindowNode.Frame frame = new WindowNode.Frame(
Expand Down Expand Up @@ -1882,7 +1883,7 @@ private PlanBuilder planPatternRecognition(
PlanAndMappings coercions,
Optional<Symbol> frameEndSymbol)
{
WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get);
DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), coercions::get);

// in window frame with pattern recognition, the frame extent is specified as `ROWS BETWEEN CURRENT ROW AND ... `
WindowFrame frame = window.getFrame().orElseThrow();
Expand Down Expand Up @@ -1949,7 +1950,7 @@ private PlanBuilder planPatternRecognition(
components.getVariableDefinitions()));
}

public static WindowNode.Specification planWindowSpecification(List<Expression> partitionBy, Optional<OrderBy> orderBy, Function<Expression, Symbol> expressionRewrite)
public static DataOrganizationSpecification planWindowSpecification(List<Expression> partitionBy, Optional<OrderBy> orderBy, Function<Expression, Symbol> expressionRewrite)
{
// Rewrite PARTITION BY
ImmutableList.Builder<Symbol> partitionBySymbols = ImmutableList.builder();
Expand All @@ -1970,7 +1971,7 @@ public static WindowNode.Specification planWindowSpecification(List<Expression>
orderingScheme = Optional.of(new OrderingScheme(ImmutableList.copyOf(orderings.keySet()), orderings));
}

return new WindowNode.Specification(partitionBySymbols.build(), orderingScheme);
return new DataOrganizationSpecification(partitionBySymbols.build(), orderingScheme);
}

private PlanBuilder planWindowMeasures(Node node, PlanBuilder subPlan, List<WindowOperation> windowMeasures)
Expand Down Expand Up @@ -2031,7 +2032,7 @@ private PlanBuilder planPatternRecognition(
ResolvedWindow window,
Optional<Symbol> frameEndSymbol)
{
WindowNode.Specification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), subPlan::translate);
DataOrganizationSpecification specification = planWindowSpecification(window.getPartitionBy(), window.getOrderBy(), subPlan::translate);

// in window frame with pattern recognition, the frame extent is specified as `ROWS BETWEEN CURRENT ROW AND ... `
WindowFrame frame = window.getFrame().orElseThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import io.trino.Session;
Expand All @@ -27,13 +28,16 @@
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.Analysis.TableArgumentAnalysis;
import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis;
import io.trino.sql.analyzer.Analysis.UnnestAnalysis;
import io.trino.sql.analyzer.Field;
import io.trino.sql.analyzer.RelationType;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.QueryPlanner.PlanAndMappings;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.ExceptNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.IntersectNode;
Expand All @@ -49,7 +53,6 @@
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.RowPatternToIrRewriter;
Expand Down Expand Up @@ -88,9 +91,7 @@
import io.trino.sql.tree.SubqueryExpression;
import io.trino.sql.tree.SubsetDefinition;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableFunctionDescriptorArgument;
import io.trino.sql.tree.TableFunctionInvocation;
import io.trino.sql.tree.TableFunctionTableArgument;
import io.trino.sql.tree.TableSubquery;
import io.trino.sql.tree.Union;
import io.trino.sql.tree.Unnest;
Expand All @@ -106,6 +107,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -122,6 +124,7 @@
import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions;
import static io.trino.sql.planner.QueryPlanner.planWindowSpecification;
import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields;
import static io.trino.sql.planner.QueryPlanner.translateOrderingScheme;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
Expand Down Expand Up @@ -329,46 +332,99 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan)
@Override
protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, Void context)
{
node.getArguments().stream()
.forEach(argument -> {
if (argument.getValue() instanceof TableFunctionTableArgument) {
throw semanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions");
}
if (argument.getValue() instanceof TableFunctionDescriptorArgument) {
throw semanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions");
}
});

TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node);

// TODO handle input relations:
// 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested.
// 2. for each input relation, prepare the TableArgumentProperties record, consisting of:
// - row or set semantics (from the actualArgument)
// - prune when empty property (from the actualArgument)
// - pass through columns property (from the actualArgument)
// - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources)
// TODO add - argument name
// TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType
List<RelationPlan> sources = ImmutableList.of();
List<TableArgumentProperties> inputRelationsProperties = ImmutableList.of();
ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
ImmutableList.Builder<TableArgumentProperties> sourceProperties = ImmutableList.builder();
ImmutableList.Builder<Symbol> outputSymbols = ImmutableList.builder();

Scope scope = analysis.getScope(node);
// TODO pass columns from input relations, and make sure they have the right qualifier
List<Symbol> outputSymbols = scope.getRelationType().getAllFields().stream()
// create new symbols for table function's proper columns
RelationType relationType = analysis.getScope(node).getRelationType();
List<Symbol> properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount())
.mapToObj(relationType::getFieldByIndex)
.map(symbolAllocator::newSymbol)
.collect(toImmutableList());

outputSymbols.addAll(properOutputs);

// process sources in order of argument declarations
for (TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) {
RelationPlan sourcePlan = process(tableArgument.getRelation(), context);
PlanBuilder sourcePlanBuilder = newPlanBuilder(sourcePlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext);

// map column names to symbols
// note: hidden columns are included in the mapping. They are present both in sourceDescriptor.allFields, and in sourcePlan.fieldMappings
// note: for an aliased relation or a CTE, the field names in the relation type are in the same case as specified in the alias.
// quotes and canonicalization rules are not applied.
ImmutableMultimap.Builder<String, Symbol> columnMapping = ImmutableMultimap.builder();
RelationType sourceDescriptor = sourcePlan.getDescriptor();
for (int i = 0; i < sourceDescriptor.getAllFieldCount(); i++) {
Optional<String> name = sourceDescriptor.getFieldByIndex(i).getName();
if (name.isPresent()) {
columnMapping.put(name.get(), sourcePlan.getSymbol(i));
}
}

Optional<DataOrganizationSpecification> specification = Optional.empty();

// if the table argument has set semantics, create Specification
if (!tableArgument.isRowSemantics()) {
// partition by
List<Symbol> partitionBy = ImmutableList.of();
// if there are partitioning columns, they might have to be coerced for copartitioning
if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) {
List<Expression> partitioningColumns = tableArgument.getPartitionBy().get();
PlanAndMappings copartitionCoercions = coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, symbolAllocator, typeCoercion);
sourcePlanBuilder = copartitionCoercions.getSubPlan();
partitionBy = partitioningColumns.stream()
.map(copartitionCoercions::get)
.collect(toImmutableList());
}

// order by
Optional<OrderingScheme> orderBy = Optional.empty();
if (tableArgument.getOrderBy().isPresent()) {
// the ordering symbols are not coerced
orderBy = Optional.of(translateOrderingScheme(tableArgument.getOrderBy().get().getSortItems(), sourcePlanBuilder::translate));
}

specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy));
}

sources.add(sourcePlanBuilder.getRoot());
sourceProperties.add(new TableArgumentProperties(
tableArgument.getArgumentName(),
columnMapping.build(),
tableArgument.isRowSemantics(),
tableArgument.isPruneWhenEmpty(),
tableArgument.isPassThroughColumns(),
specification));

// add output symbols passed from the table argument
if (tableArgument.isPassThroughColumns()) {
// the original output symbols from the source node, not coerced
// note: hidden columns are included. They are present in sourcePlan.fieldMappings
outputSymbols.addAll(sourcePlan.getFieldMappings());
}
else if (tableArgument.getPartitionBy().isPresent()) {
tableArgument.getPartitionBy().get().stream()
// the original symbols for partitioning columns, not coerced
.map(sourcePlanBuilder::translate)
.forEach(outputSymbols::add);
}
}

PlanNode root = new TableFunctionNode(
idAllocator.getNextId(),
functionAnalysis.getFunctionName(),
functionAnalysis.getArguments(),
outputSymbols,
sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()),
inputRelationsProperties,
properOutputs,
sources.build(),
sourceProperties.build(),
functionAnalysis.getCopartitioningLists(),
new TableFunctionHandle(functionAnalysis.getCatalogHandle(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle()));

return new RelationPlan(root, scope, outputSymbols, outerContext);
return new RelationPlan(root, analysis.getScope(node), outputSymbols.build(), outerContext);
}

@Override
Expand Down Expand Up @@ -416,7 +472,7 @@ protected RelationPlan visitPatternRecognitionRelation(PatternRecognitionRelatio
ImmutableList.Builder<Symbol> outputLayout = ImmutableList.builder();
boolean oneRowOutput = node.getRowsPerMatch().isEmpty() || node.getRowsPerMatch().get().isOneRow();

WindowNode.Specification specification = planWindowSpecification(node.getPartitionBy(), node.getOrderBy(), planBuilder::translate);
DataOrganizationSpecification specification = planWindowSpecification(node.getPartitionBy(), node.getOrderBy(), planBuilder::translate);
outputLayout.addAll(specification.getPartitionBy());
if (!oneRowOutput) {
getSortItemsFromOrderBy(node.getOrderBy()).stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode.Type;
Expand All @@ -41,7 +42,6 @@
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.plan.WindowNode.Specification;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
Expand Down Expand Up @@ -473,7 +473,7 @@ public RewriteResult visitTopN(TopNNode node, Void context)
WindowNode windowNode = new WindowNode(
idAllocator.getNextId(),
source.getPlan(),
new Specification(ImmutableList.of(uniqueSymbol), Optional.of(node.getOrderingScheme())),
new DataOrganizationSpecification(ImmutableList.of(uniqueSymbol), Optional.of(node.getOrderingScheme())),
ImmutableMap.of(rowNumberSymbol, rowNumberFunction),
Optional.empty(),
ImmutableSet.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.PlanNode;
Expand Down Expand Up @@ -128,7 +129,7 @@ public static PlanNode rewriteLimitWithTiesWithPartitioning(LimitNode limitNode,
WindowNode windowNode = new WindowNode(
idAllocator.getNextId(),
source,
new WindowNode.Specification(partitionBy, limitNode.getTiesResolvingScheme()),
new DataOrganizationSpecification(partitionBy, limitNode.getTiesResolvingScheme()),
ImmutableMap.of(rankSymbol, rankFunction),
Optional.empty(),
ImmutableSet.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SubscriptExpression;
import io.trino.sql.tree.SymbolReference;
Expand Down Expand Up @@ -89,7 +89,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context)
Set<SubscriptExpression> dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());

// Exclude dereferences on symbols being used in partitionBy and orderBy
WindowNode.Specification specification = topNRankingNode.getSpecification();
DataOrganizationSpecification specification = topNRankingNode.getSpecification();
dereferences = dereferences.stream()
.filter(expression -> {
Symbol symbol = getBase(expression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.Expression;
Expand Down Expand Up @@ -99,7 +100,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context)
typeAnalyzer,
context.getSymbolAllocator().getTypes());

WindowNode.Specification specification = windowNode.getSpecification();
DataOrganizationSpecification specification = windowNode.getSpecification();
dereferences = dereferences.stream()
.filter(expression -> {
Symbol symbol = getBase(expression);
Expand Down
Loading