Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPI and optimizer rule for connectors that can support complete topN … #4249

Merged
merged 1 commit into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions presto-main/src/main/java/io/prestosql/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
Expand Down Expand Up @@ -367,6 +369,13 @@ Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
Map<String, ColumnHandle> assignments,
List<List<ColumnHandle>> groupingSets);

Optional<TopNApplicationResult<TableHandle>> applyTopN(
Session session,
TableHandle handle,
long topNCount,
List<SortItem> sortItems,
Map<String, ColumnHandle> assignments);

default void validateScan(Session session, TableHandle table) {}

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.spi.function.InvocationConvention;
Expand Down Expand Up @@ -1160,6 +1162,28 @@ public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
});
}

@Override
public Optional<TopNApplicationResult<TableHandle>> applyTopN(
Session session,
TableHandle table,
long topNCount,
List<SortItem> sortItems,
Map<String, ColumnHandle> assignments)
{
CatalogName catalogName = table.getCatalogName();
ConnectorMetadata metadata = getMetadata(session, catalogName);

if (metadata.usesLegacyTableLayouts()) {
return Optional.empty();
}

ConnectorSession connectorSession = session.toConnectorSession(catalogName);
return metadata.applyTopN(connectorSession, table.getConnectorHandle(), topNCount, sortItems, assignments)
.map(result -> new TopNApplicationResult<>(
new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()),
result.isTopNGuaranteed()));
}

private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
{
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,13 @@ public static SortOrder sortItemToSortOrder(SortItem sortItem)
}
return SortOrder.DESC_NULLS_LAST;
}

public List<io.prestosql.spi.connector.SortItem> toSortItems()
{
return getOrderBy().stream()
.map(symbol -> new io.prestosql.spi.connector.SortItem(
symbol.getName(),
io.prestosql.spi.connector.SortOrder.valueOf(getOrdering(symbol).name())))
.collect(toImmutableList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
import io.prestosql.sql.planner.iterative.rule.PushSampleIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import io.prestosql.sql.planner.iterative.rule.PushTopNIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughUnion;
Expand Down Expand Up @@ -623,7 +624,8 @@ public PlanOptimizers(
new CreatePartialTopN(),
new PushTopNThroughProject(),
new PushTopNThroughOuterJoin(),
new PushTopNThroughUnion())));
new PushTopNThroughUnion(),
new PushTopNIntoTableScan(metadata))));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be added to pushIntoTableScanOptimizer, when other PushXxxIntoTableScan rules are registered.

This is important for two reasons:

  • this will unlock other pushdowns once TopN is pushed into TableScan fully
  • this will let the rule operate on TopNNode while it is SINGLE step (and consume it fully), making reasoning about engine-connector interactions simpler

@wendigo is going to address this in #6847

builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SortOrder;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.expression.Variable;
import io.prestosql.sql.planner.ConnectorExpressionTranslator;
Expand Down Expand Up @@ -200,12 +199,7 @@ private static AggregateFunction toAggregateFunction(Context context, Aggregatio
}

Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
Optional<List<SortItem>> sortBy = orderingScheme.map(orderings ->
orderings.getOrderBy().stream()
.map(orderBy -> new SortItem(
orderBy.getName(),
SortOrder.valueOf(orderings.getOrderings().get(orderBy).name())))
.collect(toImmutableList()));
Optional<List<SortItem>> sortBy = orderingScheme.map(OrderingScheme::toSortItems);

Optional<ConnectorExpression> filter = aggregation.getFilter()
.map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol)));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;

import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.sql.planner.plan.Patterns.source;
import static io.prestosql.sql.planner.plan.Patterns.tableScan;
import static io.prestosql.sql.planner.plan.Patterns.topN;

public class PushTopNIntoTableScan
implements Rule<TopNNode>
{
private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

private static final Pattern<TopNNode> PATTERN = topN().with(source().matching(
tableScan().capturedAs(TABLE_SCAN)));
Comment on lines +41 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should trigger no SINGLE step only.
We should not push partial and final separately (and IMO -- we should not push them at all)


private final Metadata metadata;

public PushTopNIntoTableScan(Metadata metadata)
{
this.metadata = metadata;
}

@Override
public Pattern<TopNNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(TopNNode topNNode, Captures captures, Context context)
{
TableScanNode tableScan = captures.get(TABLE_SCAN);

long topNCount = topNNode.getCount();
List<SortItem> sortItems = topNNode.getOrderingScheme().toSortItems();

Map<String, ColumnHandle> assignments = tableScan.getAssignments()
.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));

return metadata.applyTopN(context.getSession(), tableScan.getTable(), topNCount, sortItems, assignments)
.map(result -> {
PlanNode node = TableScanNode.newInstance(
context.getIdAllocator().getNextId(),
result.getHandle(),
tableScan.getOutputSymbols(),
tableScan.getAssignments());

if (!result.isTopNGuaranteed()) {
node = new TopNNode(topNNode.getId(), node, topNNode.getCount(), topNNode.getOrderingScheme(), TopNNode.Step.FINAL);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relates to #4249 (comment)

Inserting FINAL step is not correct when the rule triggered on TopNNode with Step PARTIAL.

I suggest

  • make the rule trigger for SINGLE step only (my preferred)
    • i am aware sibling PushLimitIntoTableScan triggers for partial limits, but i am not convinced it is beneficial
  • keep same step as it used to be used (basically, use topNNode.replaceChildren)

(this is not being addressed in @wendigo #6847)

}
return Result.ofPlanNode(node);
})
.orElseGet(Result::empty);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.eventlistener.EventListener;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.security.PrestoPrincipal;
Expand Down Expand Up @@ -75,6 +77,7 @@ public class MockConnectorFactory
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final ApplyProjection applyProjection;
private final ApplyTopN applyTopN;
private final BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout;
private final BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout;
private final Supplier<Iterable<EventListener>> eventListeners;
Expand All @@ -87,6 +90,7 @@ private MockConnectorFactory(
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
ApplyProjection applyProjection,
ApplyTopN applyTopN,
BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout,
BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout,
Supplier<Iterable<EventListener>> eventListeners,
Expand All @@ -98,6 +102,7 @@ private MockConnectorFactory(
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = getColumns;
this.applyProjection = applyProjection;
this.applyTopN = requireNonNull(applyTopN, "applyTopN is null");
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null");
this.eventListeners = requireNonNull(eventListeners, "eventListeners is null");
Expand All @@ -119,7 +124,7 @@ public ConnectorHandleResolver getHandleResolver()
@Override
public Connector create(String catalogName, Map<String, String> config, ConnectorContext context)
{
return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
}

public static Builder builder()
Expand All @@ -133,6 +138,12 @@ public interface ApplyProjection
Optional<ProjectionApplicationResult<ConnectorTableHandle>> apply(ConnectorSession session, ConnectorTableHandle handle, List<ConnectorExpression> projections, Map<String, ColumnHandle> assignments);
}

@FunctionalInterface
public interface ApplyTopN
{
Optional<TopNApplicationResult<ConnectorTableHandle>> apply(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List<SortItem> sortItems, Map<String, ColumnHandle> assignments);
}

@FunctionalInterface
public interface ListRoleGrants
{
Expand All @@ -149,6 +160,7 @@ public static class MockConnector
private final BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle;
private final Function<SchemaTableName, List<ColumnMetadata>> getColumns;
private final ApplyProjection applyProjection;
private final ApplyTopN applyTopN;
private final BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout;
private final BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout;
private final Supplier<Iterable<EventListener>> eventListeners;
Expand All @@ -162,6 +174,7 @@ private MockConnector(
BiFunction<ConnectorSession, SchemaTableName, ConnectorTableHandle> getTableHandle,
Function<SchemaTableName, List<ColumnMetadata>> getColumns,
ApplyProjection applyProjection,
ApplyTopN applyTopN,
BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout,
BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout,
Supplier<Iterable<EventListener>> eventListeners,
Expand All @@ -174,6 +187,7 @@ private MockConnector(
this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null");
this.getColumns = requireNonNull(getColumns, "getColumns is null");
this.applyProjection = requireNonNull(applyProjection, "applyProjection is null");
this.applyTopN = requireNonNull(applyTopN, "applyTopN is null");
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null");
this.eventListeners = requireNonNull(eventListeners, "eventListeners is null");
Expand Down Expand Up @@ -219,6 +233,12 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
return applyProjection.apply(session, handle, projections, assignments);
}

@Override
public Optional<TopNApplicationResult<ConnectorTableHandle>> applyTopN(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List<SortItem> sortItems, Map<String, ColumnHandle> assignments)
{
return applyTopN.apply(session, handle, topNCount, sortItems, assignments);
}

@Override
public List<String> listSchemaNames(ConnectorSession session)
{
Expand Down Expand Up @@ -414,6 +434,7 @@ public static final class Builder
private BiFunction<ConnectorSession, ConnectorTableMetadata, Optional<ConnectorNewTableLayout>> getNewTableLayout = defaultGetNewTableLayout();
private Supplier<Iterable<EventListener>> eventListeners = ImmutableList::of;
private ListRoleGrants roleGrants = defaultRoleAuthorizations();
private ApplyTopN applyTopN = (session, handle, topNCount, sortItems, assignments) -> Optional.empty();

public Builder withListSchemaNames(Function<ConnectorSession, List<String>> listSchemaNames)
{
Expand Down Expand Up @@ -457,6 +478,12 @@ public Builder withApplyProjection(ApplyProjection applyProjection)
return this;
}

public Builder withApplyTopN(ApplyTopN applyTopN)
{
this.applyTopN = applyTopN;
return this;
}

public Builder withGetInsertLayout(BiFunction<ConnectorSession, SchemaTableName, Optional<ConnectorNewTableLayout>> getInsertLayout)
{
this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null");
Expand Down Expand Up @@ -487,7 +514,7 @@ public Builder withEventListener(Supplier<EventListener> listenerFactory)

public MockConnectorFactory build()
{
return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants);
}

public static Function<ConnectorSession, List<String>> defaultListSchemaNames()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.connector.SampleType;
import io.prestosql.spi.connector.SortItem;
import io.prestosql.spi.connector.SystemTable;
import io.prestosql.spi.connector.TopNApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.OperatorType;
Expand Down Expand Up @@ -741,4 +743,10 @@ public Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Sessio
{
return Optional.empty();
}

@Override
public Optional<TopNApplicationResult<TableHandle>> applyTopN(Session session, TableHandle handle, long topNFunctions, List<SortItem> sortItems, Map<String, ColumnHandle> assignments)
{
return Optional.empty();
}
}
Loading