From 9c94323428ea17bd492b0e809456dc2b164ba009 Mon Sep 17 00:00:00 2001 From: pbrahmbhatt Date: Fri, 26 Jun 2020 14:28:59 -0700 Subject: [PATCH] Add support for TopN pushdown --- .../java/io/prestosql/metadata/Metadata.java | 9 + .../prestosql/metadata/MetadataManager.java | 24 +++ .../prestosql/sql/planner/OrderingScheme.java | 9 + .../prestosql/sql/planner/PlanOptimizers.java | 4 +- .../rule/PushAggregationIntoTableScan.java | 8 +- .../iterative/rule/PushTopNIntoTableScan.java | 84 ++++++++ .../connector/MockConnectorFactory.java | 31 ++- .../metadata/AbstractMockMetadata.java | 8 + .../rule/TestPushTopNIntoTableScan.java | 190 ++++++++++++++++++ .../ClassLoaderSafeConnectorMetadata.java | 15 ++ .../spi/connector/ConnectorMetadata.java | 24 +++ .../spi/connector/TopNApplicationResult.java | 38 ++++ 12 files changed, 434 insertions(+), 10 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNIntoTableScan.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNIntoTableScan.java create mode 100644 presto-spi/src/main/java/io/prestosql/spi/connector/TopNApplicationResult.java diff --git a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java index 540e41319272..b87f39c88e0f 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/Metadata.java +++ b/presto-main/src/main/java/io/prestosql/metadata/Metadata.java @@ -35,7 +35,9 @@ import io.prestosql.spi.connector.LimitApplicationResult; import io.prestosql.spi.connector.ProjectionApplicationResult; import io.prestosql.spi.connector.SampleType; +import io.prestosql.spi.connector.SortItem; import io.prestosql.spi.connector.SystemTable; +import io.prestosql.spi.connector.TopNApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.function.OperatorType; @@ -367,6 +369,13 @@ Optional> applyAggregation( Map assignments, List> groupingSets); + Optional> applyTopN( + Session session, + TableHandle handle, + long topNCount, + List sortItems, + Map assignments); + default void validateScan(Session session, TableHandle table) {} // diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 3f009032496e..cc77a6a9493e 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -74,7 +74,9 @@ import io.prestosql.spi.connector.SampleType; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; +import io.prestosql.spi.connector.SortItem; import io.prestosql.spi.connector.SystemTable; +import io.prestosql.spi.connector.TopNApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.expression.Variable; import io.prestosql.spi.function.InvocationConvention; @@ -1160,6 +1162,28 @@ public Optional> applyAggregation( }); } + @Override + public Optional> applyTopN( + Session session, + TableHandle table, + long topNCount, + List sortItems, + Map assignments) + { + CatalogName catalogName = table.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + + if (metadata.usesLegacyTableLayouts()) { + return Optional.empty(); + } + + ConnectorSession connectorSession = session.toConnectorSession(catalogName); + return metadata.applyTopN(connectorSession, table.getConnectorHandle(), topNCount, sortItems, assignments) + .map(result -> new TopNApplicationResult<>( + new TableHandle(catalogName, result.getHandle(), table.getTransaction(), Optional.empty()), + result.isTopNGuaranteed())); + } + private void verifyProjection(TableHandle table, List projections, List assignments, int expectedProjectionSize) { projections.forEach(projection -> requireNonNull(projection, "one of the projections is null")); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/OrderingScheme.java b/presto-main/src/main/java/io/prestosql/sql/planner/OrderingScheme.java index fd990f9598a2..5ace2bdd21a6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/OrderingScheme.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/OrderingScheme.java @@ -129,4 +129,13 @@ public static SortOrder sortItemToSortOrder(SortItem sortItem) } return SortOrder.DESC_NULLS_LAST; } + + public List toSortItems() + { + return getOrderBy().stream() + .map(symbol -> new io.prestosql.spi.connector.SortItem( + symbol.getName(), + io.prestosql.spi.connector.SortOrder.valueOf(getOrdering(symbol).name()))) + .collect(toImmutableList()); + } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 109e45791fde..7efd9579e096 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -140,6 +140,7 @@ import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId; import io.prestosql.sql.planner.iterative.rule.PushSampleIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushTableWriteThroughUnion; +import io.prestosql.sql.planner.iterative.rule.PushTopNIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushTopNThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushTopNThroughProject; import io.prestosql.sql.planner.iterative.rule.PushTopNThroughUnion; @@ -623,7 +624,8 @@ public PlanOptimizers( new CreatePartialTopN(), new PushTopNThroughProject(), new PushTopNThroughOuterJoin(), - new PushTopNThroughUnion()))); + new PushTopNThroughUnion(), + new PushTopNIntoTableScan(metadata)))); builder.add(new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 827d4d42494f..f122adb2bb35 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -26,7 +26,6 @@ import io.prestosql.spi.connector.Assignment; import io.prestosql.spi.connector.ColumnHandle; import io.prestosql.spi.connector.SortItem; -import io.prestosql.spi.connector.SortOrder; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.expression.Variable; import io.prestosql.sql.planner.ConnectorExpressionTranslator; @@ -200,12 +199,7 @@ private static AggregateFunction toAggregateFunction(Context context, Aggregatio } Optional orderingScheme = aggregation.getOrderingScheme(); - Optional> sortBy = orderingScheme.map(orderings -> - orderings.getOrderBy().stream() - .map(orderBy -> new SortItem( - orderBy.getName(), - SortOrder.valueOf(orderings.getOrderings().get(orderBy).name()))) - .collect(toImmutableList())); + Optional> sortBy = orderingScheme.map(OrderingScheme::toSortItems); Optional filter = aggregation.getFilter() .map(symbol -> new Variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol))); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNIntoTableScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNIntoTableScan.java new file mode 100644 index 000000000000..7901f8b4dab0 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNIntoTableScan.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.metadata.Metadata; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.SortItem; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.planner.plan.TopNNode; + +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.tableScan; +import static io.prestosql.sql.planner.plan.Patterns.topN; + +public class PushTopNIntoTableScan + implements Rule +{ + private static final Capture TABLE_SCAN = newCapture(); + + private static final Pattern PATTERN = topN().with(source().matching( + tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + + public PushTopNIntoTableScan(Metadata metadata) + { + this.metadata = metadata; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TopNNode topNNode, Captures captures, Context context) + { + TableScanNode tableScan = captures.get(TABLE_SCAN); + + long topNCount = topNNode.getCount(); + List sortItems = topNNode.getOrderingScheme().toSortItems(); + + Map assignments = tableScan.getAssignments() + .entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)); + + return metadata.applyTopN(context.getSession(), tableScan.getTable(), topNCount, sortItems, assignments) + .map(result -> { + PlanNode node = TableScanNode.newInstance( + context.getIdAllocator().getNextId(), + result.getHandle(), + tableScan.getOutputSymbols(), + tableScan.getAssignments()); + + if (!result.isTopNGuaranteed()) { + node = new TopNNode(topNNode.getId(), node, topNNode.getCount(), topNNode.getOrderingScheme(), TopNNode.Step.FINAL); + } + return Result.ofPlanNode(node); + }) + .orElseGet(Result::empty); + } +} diff --git a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java index 64879b019233..1c2f4b727521 100644 --- a/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java +++ b/presto-main/src/test/java/io/prestosql/connector/MockConnectorFactory.java @@ -43,6 +43,8 @@ import io.prestosql.spi.connector.ProjectionApplicationResult; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; +import io.prestosql.spi.connector.SortItem; +import io.prestosql.spi.connector.TopNApplicationResult; import io.prestosql.spi.eventlistener.EventListener; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.security.PrestoPrincipal; @@ -75,6 +77,7 @@ public class MockConnectorFactory private final BiFunction getTableHandle; private final Function> getColumns; private final ApplyProjection applyProjection; + private final ApplyTopN applyTopN; private final BiFunction> getInsertLayout; private final BiFunction> getNewTableLayout; private final Supplier> eventListeners; @@ -87,6 +90,7 @@ private MockConnectorFactory( BiFunction getTableHandle, Function> getColumns, ApplyProjection applyProjection, + ApplyTopN applyTopN, BiFunction> getInsertLayout, BiFunction> getNewTableLayout, Supplier> eventListeners, @@ -98,6 +102,7 @@ private MockConnectorFactory( this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = getColumns; this.applyProjection = applyProjection; + this.applyTopN = requireNonNull(applyTopN, "applyTopN is null"); this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null"); this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null"); this.eventListeners = requireNonNull(eventListeners, "eventListeners is null"); @@ -119,7 +124,7 @@ public ConnectorHandleResolver getHandleResolver() @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants); + return new MockConnector(context, listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants); } public static Builder builder() @@ -133,6 +138,12 @@ public interface ApplyProjection Optional> apply(ConnectorSession session, ConnectorTableHandle handle, List projections, Map assignments); } + @FunctionalInterface + public interface ApplyTopN + { + Optional> apply(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List sortItems, Map assignments); + } + @FunctionalInterface public interface ListRoleGrants { @@ -149,6 +160,7 @@ public static class MockConnector private final BiFunction getTableHandle; private final Function> getColumns; private final ApplyProjection applyProjection; + private final ApplyTopN applyTopN; private final BiFunction> getInsertLayout; private final BiFunction> getNewTableLayout; private final Supplier> eventListeners; @@ -162,6 +174,7 @@ private MockConnector( BiFunction getTableHandle, Function> getColumns, ApplyProjection applyProjection, + ApplyTopN applyTopN, BiFunction> getInsertLayout, BiFunction> getNewTableLayout, Supplier> eventListeners, @@ -174,6 +187,7 @@ private MockConnector( this.getTableHandle = requireNonNull(getTableHandle, "getTableHandle is null"); this.getColumns = requireNonNull(getColumns, "getColumns is null"); this.applyProjection = requireNonNull(applyProjection, "applyProjection is null"); + this.applyTopN = requireNonNull(applyTopN, "applyTopN is null"); this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null"); this.getNewTableLayout = requireNonNull(getNewTableLayout, "getNewTableLayout is null"); this.eventListeners = requireNonNull(eventListeners, "eventListeners is null"); @@ -219,6 +233,12 @@ public Optional> applyProjecti return applyProjection.apply(session, handle, projections, assignments); } + @Override + public Optional> applyTopN(ConnectorSession session, ConnectorTableHandle handle, long topNCount, List sortItems, Map assignments) + { + return applyTopN.apply(session, handle, topNCount, sortItems, assignments); + } + @Override public List listSchemaNames(ConnectorSession session) { @@ -414,6 +434,7 @@ public static final class Builder private BiFunction> getNewTableLayout = defaultGetNewTableLayout(); private Supplier> eventListeners = ImmutableList::of; private ListRoleGrants roleGrants = defaultRoleAuthorizations(); + private ApplyTopN applyTopN = (session, handle, topNCount, sortItems, assignments) -> Optional.empty(); public Builder withListSchemaNames(Function> listSchemaNames) { @@ -457,6 +478,12 @@ public Builder withApplyProjection(ApplyProjection applyProjection) return this; } + public Builder withApplyTopN(ApplyTopN applyTopN) + { + this.applyTopN = applyTopN; + return this; + } + public Builder withGetInsertLayout(BiFunction> getInsertLayout) { this.getInsertLayout = requireNonNull(getInsertLayout, "getInsertLayout is null"); @@ -487,7 +514,7 @@ public Builder withEventListener(Supplier listenerFactory) public MockConnectorFactory build() { - return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, getInsertLayout, getNewTableLayout, eventListeners, roleGrants); + return new MockConnectorFactory(listSchemaNames, listTables, getViews, getTableHandle, getColumns, applyProjection, applyTopN, getInsertLayout, getNewTableLayout, eventListeners, roleGrants); } public static Function> defaultListSchemaNames() diff --git a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java index bb70ebc1f841..e03ac81f8d4a 100644 --- a/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/io/prestosql/metadata/AbstractMockMetadata.java @@ -40,7 +40,9 @@ import io.prestosql.spi.connector.LimitApplicationResult; import io.prestosql.spi.connector.ProjectionApplicationResult; import io.prestosql.spi.connector.SampleType; +import io.prestosql.spi.connector.SortItem; import io.prestosql.spi.connector.SystemTable; +import io.prestosql.spi.connector.TopNApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.function.OperatorType; @@ -741,4 +743,10 @@ public Optional> applyProjection(Sessio { return Optional.empty(); } + + @Override + public Optional> applyTopN(Session session, TableHandle handle, long topNFunctions, List sortItems, Map assignments) + { + return Optional.empty(); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNIntoTableScan.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNIntoTableScan.java new file mode 100644 index 000000000000..d025d969ecc4 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNIntoTableScan.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.Session; +import io.prestosql.connector.CatalogName; +import io.prestosql.connector.MockConnectorFactory; +import io.prestosql.connector.MockConnectorFactory.MockConnectorTableHandle; +import io.prestosql.metadata.TableHandle; +import io.prestosql.plugin.tpch.TpchColumnHandle; +import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.ColumnMetadata; +import io.prestosql.spi.connector.ConnectorTableHandle; +import io.prestosql.spi.connector.ConnectorTransactionHandle; +import io.prestosql.spi.connector.SchemaTableName; +import io.prestosql.spi.connector.TopNApplicationResult; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.iterative.rule.test.RuleTester; +import io.prestosql.sql.planner.plan.TopNNode; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Predicates.equalTo; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; +import static io.prestosql.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; +import static io.prestosql.testing.TestingSession.testSessionBuilder; + +public class TestPushTopNIntoTableScan +{ + private static final String MOCK_CATALOG = "mock_catalog"; + private static final String TEST_SCHEMA = "test_schema"; + private static final String TEST_TABLE = "test_table"; + private static final SchemaTableName TEST_SCHEMA_TABLE = new SchemaTableName(TEST_SCHEMA, TEST_TABLE); + + private static final TableHandle TEST_TABLE_HANDLE = createTableHandle(new MockConnectorTableHandle(new SchemaTableName(TEST_SCHEMA, TEST_TABLE))); + + private static final Session MOCK_SESSION = testSessionBuilder().setCatalog(MOCK_CATALOG).setSchema(TEST_SCHEMA).build(); + + private static final String dimensionName = "dimension"; + private static final ColumnHandle dimensionColumn = new TpchColumnHandle(dimensionName, VARCHAR); + private static final String metricName = "metric"; + private static final ColumnHandle metricColumn = new TpchColumnHandle(metricName, BIGINT); + + private static final ImmutableMap assignments = ImmutableMap.of( + dimensionName, dimensionColumn, + metricName, metricColumn); + + private static TableHandle createTableHandle(ConnectorTableHandle tableHandle) + { + return new TableHandle( + new CatalogName(MOCK_CATALOG), + tableHandle, + new ConnectorTransactionHandle() {}, + Optional.empty()); + } + + @Test + public void testDoesNotFire() + { + try (RuleTester ruleTester = defaultRuleTester()) { + MockConnectorFactory mockFactory = createMockFactory(assignments, Optional.empty()); + ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, mockFactory, ImmutableMap.of()); + + ruleTester.assertThat(new PushTopNIntoTableScan(ruleTester.getMetadata())) + .on(p -> { + Symbol dimension = p.symbol(dimensionName, VARCHAR); + Symbol metric = p.symbol(metricName, BIGINT); + return p.topN(1, ImmutableList.of(dimension), + p.tableScan(TEST_TABLE_HANDLE, + ImmutableList.of(dimension, metric), + ImmutableMap.of( + dimension, dimensionColumn, + metric, metricColumn))); + }) + .withSession(MOCK_SESSION) + .doesNotFire(); + } + } + + @Test + public void testPushTopNIntoTableScan() + { + try (RuleTester ruleTester = defaultRuleTester()) { + MockConnectorTableHandle connectorHandle = new MockConnectorTableHandle(TEST_SCHEMA_TABLE); + // make the mock connector return a new connectorHandle + MockConnectorFactory.ApplyTopN applyTopN = + (session, handle, topNCount, sortItems, tableAssignments) -> Optional.of(new TopNApplicationResult<>(connectorHandle, true)); + MockConnectorFactory mockFactory = createMockFactory(assignments, Optional.of(applyTopN)); + + ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, mockFactory, ImmutableMap.of()); + + ruleTester.assertThat(new PushTopNIntoTableScan(ruleTester.getMetadata())) + .on(p -> { + Symbol dimension = p.symbol(dimensionName, VARCHAR); + Symbol metric = p.symbol(metricName, BIGINT); + return p.topN(1, ImmutableList.of(dimension), + p.tableScan(TEST_TABLE_HANDLE, + ImmutableList.of(dimension, metric), + ImmutableMap.of( + dimension, dimensionColumn, + metric, metricColumn))); + }) + .withSession(MOCK_SESSION) + .matches( + tableScan( + equalTo(connectorHandle), + TupleDomain.all(), + new HashMap<>())); + } + } + + @Test + public void testPushTopNIntoTableScanPartial() + { + try (RuleTester ruleTester = defaultRuleTester()) { + MockConnectorTableHandle connectorHandle = new MockConnectorTableHandle(TEST_SCHEMA_TABLE); + // make the mock connector return a new connectorHandle + MockConnectorFactory.ApplyTopN applyTopN = + (session, handle, topNCount, sortItems, tableAssignments) -> Optional.of(new TopNApplicationResult<>(connectorHandle, false)); + MockConnectorFactory mockFactory = createMockFactory(assignments, Optional.of(applyTopN)); + + ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, mockFactory, ImmutableMap.of()); + + ruleTester.assertThat(new PushTopNIntoTableScan(ruleTester.getMetadata())) + .on(p -> { + Symbol dimension = p.symbol(dimensionName, VARCHAR); + Symbol metric = p.symbol(metricName, BIGINT); + return p.topN(1, ImmutableList.of(dimension), + p.tableScan(TEST_TABLE_HANDLE, + ImmutableList.of(dimension, metric), + ImmutableMap.of( + dimension, dimensionColumn, + metric, metricColumn))); + }) + .withSession(MOCK_SESSION) + .matches( + topN(1, ImmutableList.of(sort(dimensionName, ASCENDING, FIRST)), + TopNNode.Step.FINAL, + tableScan( + equalTo(connectorHandle), + TupleDomain.all(), + ImmutableMap.of( + dimensionName, equalTo(dimensionColumn), + metricName, equalTo(metricColumn))))); + } + } + + private MockConnectorFactory createMockFactory(Map assignments, Optional applyTopN) + { + List metadata = assignments.entrySet().stream() + .map(entry -> new ColumnMetadata(entry.getKey(), ((TpchColumnHandle) entry.getValue()).getType())) + .collect(toImmutableList()); + + MockConnectorFactory.Builder builder = MockConnectorFactory.builder() + .withListSchemaNames(connectorSession -> ImmutableList.of(TEST_SCHEMA)) + .withListTables((connectorSession, schema) -> TEST_SCHEMA.equals(schema) ? ImmutableList.of(TEST_SCHEMA_TABLE) : ImmutableList.of()) + .withGetColumns(schemaTableName -> metadata); + + if (applyTopN.isPresent()) { + builder = builder.withApplyTopN(applyTopN.get()); + } + + return builder.build(); + } +} diff --git a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index e61c22ea10d6..248b38eb7a74 100644 --- a/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-plugin-toolkit/src/main/java/io/prestosql/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -42,7 +42,9 @@ import io.prestosql.spi.connector.SampleType; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.SchemaTablePrefix; +import io.prestosql.spi.connector.SortItem; import io.prestosql.spi.connector.SystemTable; +import io.prestosql.spi.connector.TopNApplicationResult; import io.prestosql.spi.expression.ConnectorExpression; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.security.GrantInfo; @@ -718,6 +720,19 @@ public Optional> applyAggrega } } + @Override + public Optional> applyTopN( + ConnectorSession session, + ConnectorTableHandle table, + long topNCount, + List sortItems, + Map assignments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.applyTopN(session, table, topNCount, sortItems, assignments); + } + } + @Override public void validateScan(ConnectorSession session, ConnectorTableHandle handle) { diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java index 68ff879128ec..4081b899e795 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/ConnectorMetadata.java @@ -933,6 +933,30 @@ default Optional> applyAggreg return Optional.empty(); } + /** + * Attempt to push down the TopN into the table scan. + *

+ * Connectors can indicate whether they don't support topN pushdown or that the action had no effect + * by returning {@link Optional#empty()}. Connectors should expect this method may be called multiple times. + *

+ * Note: it's critical for connectors to return {@link Optional#empty()} if calling this method has no effect for that + * invocation, even if the connector generally supports topN pushdown. Doing otherwise can cause the optimizer + * to loop indefinitely. + *

+ * If the connector can handle TopN Pushdown and guarantee it will produce fewer rows than it should return a + * non-empty result with "topN guaranteed" flag set to true. + * @return + */ + default Optional> applyTopN( + ConnectorSession session, + ConnectorTableHandle handle, + long topNCount, + List sortItems, + Map assignments) + { + return Optional.empty(); + } + /** * Allows the connector to reject the table scan produced by the planner. *

diff --git a/presto-spi/src/main/java/io/prestosql/spi/connector/TopNApplicationResult.java b/presto-spi/src/main/java/io/prestosql/spi/connector/TopNApplicationResult.java new file mode 100644 index 000000000000..089a34b234e5 --- /dev/null +++ b/presto-spi/src/main/java/io/prestosql/spi/connector/TopNApplicationResult.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.spi.connector; + +import static java.util.Objects.requireNonNull; + +public class TopNApplicationResult +{ + private final T handle; + private final boolean topNGuaranteed; + + public TopNApplicationResult(T handle, boolean topNGuaranteed) + { + this.handle = requireNonNull(handle, "handle is null"); + this.topNGuaranteed = topNGuaranteed; + } + + public T getHandle() + { + return handle; + } + + public boolean isTopNGuaranteed() + { + return topNGuaranteed; + } +}