From 67d2df9ac15f8169a492dd8e9ccc44fb879424f5 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Mon, 31 Jan 2022 12:12:20 +0530 Subject: [PATCH] Consider correlation for output estimation of filter conjuncts Currently we assume that there is no correlation between the terms of a filter conjunction. This can result in underestimation as there is often some correlation between columns in real data sets. In particular, predicates inferred on the build side relation through a join with a partitioned table are often correlated with user provided predicates on the build side. Estimation for filter conjunctions now applies an exponential decay to the selectivity of each successive term to reduce chances of under estimation. optimizer.filter-conjunction-independence-factor is added to allow tuning the strength of the independence assumption. --- .../io/trino/SystemSessionProperties.java | 27 ++++ .../io/trino/cost/FilterStatsCalculator.java | 136 +++++++++++++--- .../trino/cost/PlanNodeStatsEstimateMath.java | 86 ++++++++++ .../io/trino/sql/planner/OptimizerConfig.java | 17 ++ .../trino/cost/TestFilterStatsCalculator.java | 148 +++++++++++++++++- .../io/trino/cost/TestOptimizerConfig.java | 3 + .../sphinx/admin/properties-optimizer.rst | 15 ++ .../io/trino/plugin/hive/TestShowStats.java | 10 +- .../resources/sql/presto/tpcds/q19.plan.txt | 28 ++-- .../resources/sql/presto/tpcds/q49.plan.txt | 26 +-- .../resources/sql/presto/tpcds/q52.plan.txt | 4 +- .../resources/sql/presto/tpcds/q55.plan.txt | 4 +- .../resources/sql/presto/tpcds/q68.plan.txt | 24 +-- .../resources/sql/presto/tpcds/q85.plan.txt | 40 ++--- 14 files changed, 472 insertions(+), 96 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index ae71e08703c3..e84bf9f5e401 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -46,6 +46,7 @@ import static io.trino.spi.session.PropertyMetadata.integerProperty; import static io.trino.spi.session.PropertyMetadata.longProperty; import static io.trino.spi.session.PropertyMetadata.stringProperty; +import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; import static java.lang.Math.min; @@ -122,6 +123,7 @@ public final class SystemSessionProperties public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures"; public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task"; public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled"; + public static final String FILTER_CONJUNCTION_INDEPENDENCE_FACTOR = "filter_conjunction_independence_factor"; public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort"; public static final String ALLOW_PUSHDOWN_INTO_CONNECTORS = "allow_pushdown_into_connectors"; public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown"; @@ -557,6 +559,15 @@ public SystemSessionProperties( "use a default filter factor for unknown filters in a filter node", optimizerConfig.isDefaultFilterFactorEnabled(), false), + new PropertyMetadata<>( + FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, + "Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters", + DOUBLE, + Double.class, + optimizerConfig.getFilterConjunctionIndependenceFactor(), + false, + value -> validateDoubleRange(value, FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, 0.0, 1.0), + value -> value), booleanProperty( SKIP_REDUNDANT_SORT, "Skip redundant sort operations", @@ -1134,6 +1145,17 @@ private static Integer validateIntegerValue(Object value, String property, int l return intValue; } + private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded) + { + double doubleValue = (double) value; + if (doubleValue < lowerBoundIncluded || doubleValue > upperBoundIncluded) { + throw new TrinoException( + INVALID_SESSION_PROPERTY, + format("%s must be in the range [%.2f, %.2f]: %.2f", property, lowerBoundIncluded, upperBoundIncluded, doubleValue)); + } + return doubleValue; + } + public static boolean isStatisticsCpuTimerEnabled(Session session) { return session.getSystemProperty(STATISTICS_CPU_TIMER_ENABLED, Boolean.class); @@ -1164,6 +1186,11 @@ public static boolean isDefaultFilterFactorEnabled(Session session) return session.getSystemProperty(DEFAULT_FILTER_FACTOR_ENABLED, Boolean.class); } + public static double getFilterConjunctionIndependenceFactor(Session session) + { + return session.getSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, Double.class); + } + public static boolean isSkipRedundantSort(Session session) { return session.getSystemProperty(SKIP_REDUNDANT_SORT, Boolean.class); diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index 1f9b04e990b9..f80055075ea2 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -14,8 +14,10 @@ package io.trino.cost; import com.google.common.base.VerifyException; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.security.AllowAllAccessControl; @@ -44,23 +46,29 @@ import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.SymbolReference; +import io.trino.util.DisjointSet; import javax.annotation.Nullable; import javax.inject.Inject; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalDouble; +import java.util.Set; +import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor; import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison; import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison; import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues; import static io.trino.cost.PlanNodeStatsEstimateMath.capStats; +import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount; +import static io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats; import static io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats; import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -68,6 +76,7 @@ import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.ExpressionUtils.getExpressionTypes; import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; @@ -137,7 +146,14 @@ private class FilterExpressionStatsCalculatingVisitor @Override public PlanNodeStatsEstimate process(Node node, @Nullable Void context) { - return normalizer.normalize(super.process(node, context), types); + PlanNodeStatsEstimate output; + if (input.getOutputRowCount() == 0 || input.isOutputRowCountUnknown()) { + output = input; + } + else { + output = super.process(node, context); + } + return normalizer.normalize(output, types); } @Override @@ -169,35 +185,56 @@ protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, V private PlanNodeStatsEstimate estimateLogicalAnd(List terms) { - // first try to estimate in the fair way - PlanNodeStatsEstimate estimate = process(terms.get(0)); - if (!estimate.isOutputRowCountUnknown()) { - for (int i = 1; i < terms.size(); i++) { - estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i)); + double filterConjunctionIndependenceFactor = getFilterConjunctionIndependenceFactor(session); + List estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor); + double outputRowCount = estimateCorrelatedConjunctionRowCount( + input, + estimates, + filterConjunctionIndependenceFactor); + if (isNaN(outputRowCount)) { + return PlanNodeStatsEstimate.unknown(); + } + return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedStats(estimates)), types); + } - if (estimate.isOutputRowCountUnknown()) { - break; + /** + * There can be multiple predicate expressions for the same symbol, e.g. x > 0 AND x <= 1, x BETWEEN 1 AND 10. + * We attempt to detect such cases in extractCorrelatedGroups and calculate a combined estimate for each + * such group of expressions. This is done so that we don't apply the above scaling factors when combining estimates + * from conjunction of multiple predicates on the same symbol and underestimate the output. + **/ + private List estimateCorrelatedExpressions(List terms, double filterConjunctionIndependenceFactor) + { + ImmutableList.Builder estimatesBuilder = ImmutableList.builder(); + boolean hasUnestimatedTerm = false; + for (List correlatedExpressions : extractCorrelatedGroups(terms, filterConjunctionIndependenceFactor)) { + PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown(); + for (Expression expression : correlatedExpressions) { + PlanNodeStatsEstimate estimate; + // combinedEstimate is unknown until the 1st known estimated term + if (combinedEstimate.isOutputRowCountUnknown()) { + estimate = process(expression); + } + else { + estimate = new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session, types) + .process(expression); } - } - if (!estimate.isOutputRowCountUnknown()) { - return estimate; + if (estimate.isOutputRowCountUnknown()) { + hasUnestimatedTerm = true; + } + else { + // update combinedEstimate only when the term estimate is known so that all the known estimates + // can be applied progressively through FilterExpressionStatsCalculatingVisitor calls. + combinedEstimate = estimate; + } } + estimatesBuilder.add(combinedEstimate); } - - // If some of the filters cannot be estimated, take the smallest estimate. - // Apply 0.9 filter factor as "unknown filter" factor. - Optional smallest = terms.stream() - .map(this::process) - .filter(termEstimate -> !termEstimate.isOutputRowCountUnknown()) - .sorted(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount)) - .findFirst(); - - if (smallest.isEmpty()) { - return PlanNodeStatsEstimate.unknown(); + if (hasUnestimatedTerm) { + estimatesBuilder.add(PlanNodeStatsEstimate.unknown()); } - - return smallest.get().mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT); + return estimatesBuilder.build(); } private PlanNodeStatsEstimate estimateLogicalOr(List terms) @@ -442,4 +479,53 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal) return toStatsRepresentation(type, literalValue); } } + + private static List> extractCorrelatedGroups(List terms, double filterConjunctionIndependenceFactor) + { + if (filterConjunctionIndependenceFactor == 1) { + // Allows the filters to be estimated as if there is no correlation between any of the terms + return ImmutableList.of(terms); + } + + ListMultimap expressionUniqueSymbols = ArrayListMultimap.create(); + terms.forEach(expression -> expressionUniqueSymbols.putAll(expression, extractUnique(expression))); + // Partition symbols into disjoint sets such that the symbols belonging to different disjoint sets + // do not appear together in any expression. + DisjointSet symbolsPartitioner = new DisjointSet<>(); + for (Expression term : terms) { + List expressionSymbols = expressionUniqueSymbols.get(term); + if (expressionSymbols.isEmpty()) { + continue; + } + // Ensure that symbol is added to DisjointSet when there is only one symbol in the list + symbolsPartitioner.find(expressionSymbols.get(0)); + for (int i = 1; i < expressionSymbols.size(); i++) { + symbolsPartitioner.findAndUnion(expressionSymbols.get(0), expressionSymbols.get(i)); + } + } + + // Use disjoint sets of symbols to partition the given list of expressions + List> symbolPartitions = ImmutableList.copyOf(symbolsPartitioner.getEquivalentClasses()); + checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions"); + ListMultimap expressionPartitions = ArrayListMultimap.create(); + for (Expression term : terms) { + List expressionSymbols = expressionUniqueSymbols.get(term); + int expressionPartitionId; + if (expressionSymbols.isEmpty()) { + expressionPartitionId = symbolPartitions.size(); // For expressions with no symbols + } + else { + Symbol symbol = expressionSymbols.get(0); // Lookup any symbol to find the partition id + expressionPartitionId = IntStream.range(0, symbolPartitions.size()) + .filter(partition -> symbolPartitions.get(partition).contains(symbol)) + .findFirst() + .orElseThrow(); + } + expressionPartitions.put(expressionPartitionId, term); + } + + return expressionPartitions.keySet().stream() + .map(expressionPartitions::get) + .collect(toImmutableList()); + } } diff --git a/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimateMath.java b/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimateMath.java index 05b09f8e0b4c..b692884a6460 100644 --- a/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimateMath.java +++ b/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimateMath.java @@ -13,11 +13,20 @@ */ package io.trino.cost; +import io.trino.sql.planner.Symbol; +import io.trino.util.MoreMath; + +import java.util.List; +import java.util.Map; + import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static java.lang.Double.NaN; import static java.lang.Double.isNaN; import static java.lang.Double.max; import static java.lang.Double.min; +import static java.util.Comparator.comparingDouble; import static java.util.stream.Stream.concat; public final class PlanNodeStatsEstimateMath @@ -135,6 +144,83 @@ public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNo return result.build(); } + public static Map intersectCorrelatedStats(List estimates) + { + checkArgument(!estimates.isEmpty(), "estimates is empty"); + if (estimates.size() == 1) { + return estimates.get(0).getSymbolStatistics(); + } + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); + // Update statistic range for symbols + estimates.stream().flatMap(estimate -> estimate.getSymbolsWithKnownStatistics().stream()) + .distinct() + .forEach(symbol -> { + List symbolStatsEstimates = estimates.stream() + .map(estimate -> estimate.getSymbolStatistics(symbol)) + .collect(toImmutableList()); + + StatisticRange intersect = symbolStatsEstimates.stream() + .map(StatisticRange::from) + .reduce(StatisticRange::intersect) + .orElseThrow(); + + // intersectCorrelatedStats should try to produce stats as if filters are applied in sequence. + // Using min works for filters like (a > 10 AND b < 10), but won't work for + // (a > 10 AND b IS NULL). However, former case is more common. + double nullsFraction = symbolStatsEstimates.stream() + .map(SymbolStatsEstimate::getNullsFraction) + .reduce(MoreMath::minExcludeNaN) + .orElseThrow(); + + double averageRowSize = symbolStatsEstimates.stream() + .map(SymbolStatsEstimate::getAverageRowSize) + .reduce(MoreMath::averageExcludingNaNs) + .orElseThrow(); + + result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder() + .setStatisticsRange(intersect) + .setNullsFraction(nullsFraction) + .setAverageRowSize(averageRowSize) + .build()); + }); + return result.build().getSymbolStatistics(); + } + + public static double estimateCorrelatedConjunctionRowCount( + PlanNodeStatsEstimate input, + List estimates, + double independenceFactor) + { + checkArgument(!estimates.isEmpty(), "estimates is empty"); + if (input.isOutputRowCountUnknown() || input.getOutputRowCount() == 0) { + return input.getOutputRowCount(); + } + List knownSortedEstimates = estimates.stream() + .filter(estimateInfo -> !estimateInfo.isOutputRowCountUnknown()) + .sorted(comparingDouble(PlanNodeStatsEstimate::getOutputRowCount)) + .collect(toImmutableList()); + if (knownSortedEstimates.isEmpty()) { + return NaN; + } + + PlanNodeStatsEstimate combinedEstimate = knownSortedEstimates.get(0); + double combinedSelectivity = combinedEstimate.getOutputRowCount() / input.getOutputRowCount(); + double combinedIndependenceFactor = 1.0; + // For independenceFactor = 0.75 and terms t1, t2, t3 + // Combined selectivity = (t1 selectivity) * ((t2 selectivity) ^ 0.75) * ((t3 selectivity) ^ (0.75 * 0.75)) + // independenceFactor = 1 implies the terms are assumed to have no correlation and their selectivities are multiplied without scaling. + // independenceFactor = 0 implies the terms are assumed to be fully correlated and only the most selective term drives the selectivity. + for (int i = 1; i < knownSortedEstimates.size(); i++) { + PlanNodeStatsEstimate term = knownSortedEstimates.get(i); + combinedIndependenceFactor *= independenceFactor; + combinedSelectivity *= Math.pow(term.getOutputRowCount() / input.getOutputRowCount(), combinedIndependenceFactor); + } + double outputRowCount = input.getOutputRowCount() * combinedSelectivity; + // TODO use UNKNOWN_FILTER_COEFFICIENT only when default-filter-factor is enabled + boolean hasUnestimatedTerm = estimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown); + return hasUnestimatedTerm ? outputRowCount * UNKNOWN_FILTER_COEFFICIENT : outputRowCount; + } + private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats) { PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java index 9cc9ff382578..e41985d4176b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java @@ -19,6 +19,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; +import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; @@ -43,6 +44,7 @@ public class OptimizerConfig private boolean collectPlanStatisticsForAllQueries; private boolean ignoreStatsCalculatorFailures = true; private boolean defaultFilterFactorEnabled; + private double filterConjunctionIndependenceFactor = 0.75; private boolean colocatedJoinsEnabled; private boolean distributedIndexJoinsEnabled; @@ -258,6 +260,21 @@ public OptimizerConfig setDefaultFilterFactorEnabled(boolean defaultFilterFactor return this; } + @Min(0) + @Max(1) + public double getFilterConjunctionIndependenceFactor() + { + return filterConjunctionIndependenceFactor; + } + + @Config("optimizer.filter-conjunction-independence-factor") + @ConfigDescription("Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters") + public OptimizerConfig setFilterConjunctionIndependenceFactor(double filterConjunctionIndependenceFactor) + { + this.filterConjunctionIndependenceFactor = filterConjunctionIndependenceFactor; + return this; + } + public boolean isColocatedJoinsEnabled() { return colocatedJoinsEnabled; diff --git a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java index 7e2c0c9aab2e..b814aa82a3b7 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestFilterStatsCalculator.java @@ -27,6 +27,9 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.function.Consumer; + +import static io.trino.SystemSessionProperties.FILTER_CONJUNCTION_INDEPENDENCE_FACTOR; import static io.trino.sql.ExpressionTestUtils.planExpression; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; @@ -257,6 +260,13 @@ public void testUnsupportedExpression() @Test public void testAndStats() { + // unknown input + assertExpression("x < 0e0 AND x < 1e0", PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); + assertExpression("x < 0e0 AND y < 1e0", PlanNodeStatsEstimate.unknown()).outputRowsCountUnknown(); + // zeroStatistics input + assertExpression("x < 0e0 AND x < 1e0", zeroStatistics).equalTo(zeroStatistics); + assertExpression("x < 0e0 AND y < 1e0", zeroStatistics).equalTo(zeroStatistics); + assertExpression("x < 0e0 AND x > 1e0").equalTo(zeroStatistics); assertExpression("x < 0e0 AND x > DOUBLE '-7.5'") @@ -301,6 +311,128 @@ public void testAndStats() assertExpression("CAST(NULL AS boolean) AND CAST(NULL AS boolean)").equalTo(zeroStatistics); assertExpression("CAST(NULL AS boolean) AND (x < 0e0 AND x > 1e0)").equalTo(zeroStatistics); + + Consumer symbolAssertX = symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(-5.0) + .highValue(5.0) + .distinctValuesCount(20.0) + .nullsFraction(0.0); + Consumer symbolAssertY = symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(1.0) + .highValue(5.0) + .distinctValuesCount(16.0) + .nullsFraction(0.0); + + double inputRowCount = standardInputStatistics.getOutputRowCount(); + double filterSelectivityX = 0.375; + double inequalityFilterSelectivityY = 0.4; + assertExpression( + "(x BETWEEN -5 AND 5) AND y > 1", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0").build()) + .outputRowsCount(filterSelectivityX * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssertY); + + assertExpression( + "(x BETWEEN -5 AND 5) AND y > 1", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "1").build()) + .outputRowsCount(filterSelectivityX * inequalityFilterSelectivityY * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssertY); + + assertExpression( + "(x BETWEEN -5 AND 5) AND y > 1", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(filterSelectivityX * (Math.pow(inequalityFilterSelectivityY, 0.5)) * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssertY); + + double nullFilterSelectivityY = 0.5; + assertExpression( + "(x BETWEEN -5 AND 5) AND y IS NULL", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "1").build()) + .outputRowsCount(filterSelectivityX * nullFilterSelectivityY * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); + + assertExpression( + "(x BETWEEN -5 AND 5) AND y IS NULL", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(filterSelectivityX * Math.pow(nullFilterSelectivityY, 0.5) * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); + + assertExpression( + "(x BETWEEN -5 AND 5) AND y IS NULL", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0").build()) + .outputRowsCount(filterSelectivityX * inputRowCount) + .symbolStats("x", symbolAssertX) + .symbolStats("y", symbolAssert -> symbolAssert.isEqualTo(SymbolStatsEstimate.zero())); + + assertExpression( + "y < 1 AND 0 < y", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(100) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(1.0) + .distinctValuesCount(4.0) + .nullsFraction(0.0)); + + assertExpression( + "x > 0 AND (y < 1 OR y > 2)", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(filterSelectivityX * (Math.pow(inequalityFilterSelectivityY, 0.5)) * inputRowCount) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(10.0) + .distinctValuesCount(20.0) + .nullsFraction(0.0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(5.0) + .distinctValuesCount(16.0) + .nullsFraction(0.0)); + + assertExpression( + "x > 0 AND (x < 1 OR y > 1)", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(172.0) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(10.0) + .distinctValuesCount(20.0) + .nullsFraction(0.0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(5.0) + .distinctValuesCount(20.0) + .nullsFraction(0.1053779069)); + + assertExpression( + "x IN (0, 1, 2) AND (x = 0 OR (x = 1 AND y = 1) OR (x = 2 AND y = 1))", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(20.373798) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(2.0) + .distinctValuesCount(2.623798) + .nullsFraction(0.0)) + .symbolStats("y", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(5.0) + .distinctValuesCount(15.686298) + .nullsFraction(0.2300749269)); + + assertExpression( + "x > 0 AND CAST(NULL AS boolean)", + Session.builder(session).setSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, "0.5").build()) + .outputRowsCount(filterSelectivityX * inputRowCount * 0.9) + .symbolStats("x", symbolAssert -> symbolAssert.averageRowSize(4.0) + .lowValue(0.0) + .highValue(10.0) + .distinctValuesCount(20.0) + .nullsFraction(0.0)); } @Test @@ -573,16 +705,26 @@ public void testInPredicateFilter() private PlanNodeStatsAssertion assertExpression(String expression) { - return assertExpression(planExpression(PLANNER_CONTEXT, session, standardTypes, expression(expression))); + return assertExpression(expression, session); + } + + private PlanNodeStatsAssertion assertExpression(String expression, PlanNodeStatsEstimate inputStatistics) + { + return assertExpression(planExpression(PLANNER_CONTEXT, session, standardTypes, expression(expression)), session, inputStatistics); + } + + private PlanNodeStatsAssertion assertExpression(String expression, Session session) + { + return assertExpression(planExpression(PLANNER_CONTEXT, session, standardTypes, expression(expression)), session, standardInputStatistics); } - private PlanNodeStatsAssertion assertExpression(Expression expression) + private PlanNodeStatsAssertion assertExpression(Expression expression, Session session, PlanNodeStatsEstimate inputStatistics) { return transaction(new TestingTransactionManager(), new AllowAllAccessControl()) .singleStatement() .execute(session, transactionSession -> { return PlanNodeStatsAssertion.assertThat(statsCalculator.filterStats( - standardInputStatistics, + inputStatistics, expression, transactionSession, standardTypes)); diff --git a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java index edc6ec7e1608..f513d3a68e0e 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestOptimizerConfig.java @@ -56,6 +56,7 @@ public void testDefaults() .setCollectPlanStatisticsForAllQueries(false) .setIgnoreStatsCalculatorFailures(true) .setDefaultFilterFactorEnabled(false) + .setFilterConjunctionIndependenceFactor(0.75) .setOptimizeMetadataQueries(false) .setOptimizeHashGeneration(true) .setPushTableWriteThroughUnion(true) @@ -98,6 +99,7 @@ public void testExplicitPropertyMappings() .put("collect-plan-statistics-for-all-queries", "true") .put("optimizer.ignore-stats-calculator-failures", "false") .put("optimizer.default-filter-factor-enabled", "true") + .put("optimizer.filter-conjunction-independence-factor", "1.0") .put("join-distribution-type", "BROADCAST") .put("join-max-broadcast-table-size", "42GB") .put("optimizer.join-reordering-strategy", "NONE") @@ -157,6 +159,7 @@ public void testExplicitPropertyMappings() .setUsePreferredWritePartitioning(false) .setPreferredWritePartitioningMinNumberOfPartitions(10) .setDefaultFilterFactorEnabled(true) + .setFilterConjunctionIndependenceFactor(1.0) .setOptimizeMetadataQueries(true) .setOptimizeHashGeneration(false) .setOptimizeMixedDistinctAggregations(true) diff --git a/docs/src/main/sphinx/admin/properties-optimizer.rst b/docs/src/main/sphinx/admin/properties-optimizer.rst index 974b81a362db..b4eef30a21a9 100644 --- a/docs/src/main/sphinx/admin/properties-optimizer.rst +++ b/docs/src/main/sphinx/admin/properties-optimizer.rst @@ -146,3 +146,18 @@ Specifies minimal bucket to task ratio that has to be matched or exceeded in ord to use table scan node partitioning. When the table bucket count is small compared to the number of workers, then the table scan is distributed across all workers for improved parallelism. + +``optimizer.filter-conjunction-independence-factor`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** :ref:`prop-type-double` +* **Default value:** ``0.75`` +* **Min allowed value:** ``0`` +* **Max allowed value:** ``1`` + +Scales the strength of independence assumption for estimating the selectivity of +the conjunction of multiple predicates. Lower values for this property will produce +more conservative estimates by assuming a greater degree of correlation between the +columns of the predicates in a conjunction. A value of ``0`` results in the +optimizer assuming that the columns of the predicates are fully correlated and only +the most selective predicate drives the selectivity of a conjunction of predicates. diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java index 539263829a71..8f2222b668f6 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java @@ -704,11 +704,11 @@ public void testShowStatsWithView() assertQuery( "SHOW STATS FOR (SELECT * FROM nation_view WHERE regionkey = 0)", "VALUES " + - " ('nationkey', null, 1, 0, null, 0, 24), " + - " ('name', 7.08, 1, 0, null, null, null), " + - " ('comment', 74.28, 1, 0, null, null, null), " + - " ('regionkey', null, 1, 0, null, 0, 0), " + - " (null, null, null, null, 1, null, null)"); + " ('nationkey', null, 0.29906975624424414, 0, null, 0, 24), " + + " ('name', 2.1174138742092485, 0.29906975624424414, 0, null, null, null), " + + " ('comment', 22.214901493822456, 0.29906975624424414, 0, null, null, null), " + + " ('regionkey', null, 0.29906975624424414, 0, null, 0, 0), " + + " (null, null, null, null, 0.29906975624424414, null, null)"); assertUpdate("DROP VIEW nation_view"); } diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q19.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q19.plan.txt index c500acf6fca3..5b7f87c8d34c 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q19.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q19.plan.txt @@ -6,24 +6,24 @@ local exchange (GATHER, SINGLE, []) partial aggregation over (i_brand, i_brand_id, i_manufact, i_manufact_id) join (INNER, REPLICATED): join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["ca_address_sk"]) - scan customer_address - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["c_current_addr_sk"]) - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["c_customer_sk"]) - scan customer - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) + remote exchange (REPARTITION, HASH, ["c_current_addr_sk"]) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, ["c_customer_sk"]) + scan customer + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) + join (INNER, REPLICATED): join (INNER, REPLICATED): - join (INNER, REPLICATED): - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + scan store_sales local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan item + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["ca_address_sk"]) + scan customer_address local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan store diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q49.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q49.plan.txt index 278d8dae514d..3ac14bb8f136 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q49.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q49.plan.txt @@ -38,21 +38,21 @@ local exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan date_dim - remote exchange (REPARTITION, HASH, ["expr_104", "expr_99", "rank_101", "rank_102", "sr_item_sk"]) - partial aggregation over (expr_104, expr_99, rank_101, rank_102, sr_item_sk) + remote exchange (REPARTITION, HASH, ["expr_104", "expr_99", "rank_101", "rank_102", "ss_item_sk"]) + partial aggregation over (expr_104, expr_99, rank_101, rank_102, ss_item_sk) local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) - final aggregation over (sr_item_sk) + final aggregation over (ss_item_sk) local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["sr_item_sk"]) - partial aggregation over (sr_item_sk) + remote exchange (REPARTITION, HASH, ["ss_item_sk"]) + partial aggregation over (ss_item_sk) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["sr_item_sk", "sr_ticket_number"]) - scan store_returns + remote exchange (REPARTITION, HASH, ["ss_item_sk", "ss_ticket_number"]) + join (INNER, REPLICATED): + scan store_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ss_item_sk", "ss_ticket_number"]) - join (INNER, REPLICATED): - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + remote exchange (REPARTITION, HASH, ["sr_item_sk", "sr_ticket_number"]) + scan store_returns diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q52.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q52.plan.txt index 33752e693e6c..70aae3fa1bdd 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q52.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q52.plan.txt @@ -9,7 +9,7 @@ local exchange (GATHER, SINGLE, []) scan store_sales local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + scan item local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan item + scan date_dim diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q55.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q55.plan.txt index 0322599f01e5..f89f2f078b63 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q55.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q55.plan.txt @@ -9,7 +9,7 @@ local exchange (GATHER, SINGLE, []) scan store_sales local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + scan item local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan item + scan date_dim diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q68.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q68.plan.txt index 2ef8a760155b..46b2fc2f8db8 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q68.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q68.plan.txt @@ -7,27 +7,27 @@ local exchange (GATHER, SINGLE, []) scan customer local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ss_customer_sk"]) - final aggregation over (ca_address_sk, ca_city, ss_customer_sk, ss_ticket_number) + final aggregation over (ca_city, ss_addr_sk, ss_customer_sk, ss_ticket_number) local exchange (GATHER, SINGLE, []) - partial aggregation over (ca_address_sk, ca_city, ss_customer_sk, ss_ticket_number) + partial aggregation over (ca_city, ss_addr_sk, ss_customer_sk, ss_ticket_number) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["ca_address_sk"]) - scan customer_address - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ss_addr_sk"]) + remote exchange (REPARTITION, HASH, ["ss_addr_sk"]) + join (INNER, REPLICATED): join (INNER, REPLICATED): join (INNER, REPLICATED): - join (INNER, REPLICATED): - scan store_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + scan store_sales local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan store local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) - scan household_demographics + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan household_demographics + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["ca_address_sk"]) + scan customer_address local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, ["ca_address_sk_2"]) scan customer_address diff --git a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q85.plan.txt b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q85.plan.txt index 209c0a446935..8a4986ba474a 100644 --- a/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q85.plan.txt +++ b/testing/trino-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q85.plan.txt @@ -7,29 +7,29 @@ local exchange (GATHER, SINGLE, []) join (INNER, REPLICATED): join (INNER, REPLICATED): join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["cd_demo_sk_0", "cd_education_status_3", "cd_marital_status_2"]) - scan customer_demographics - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["cd_education_status", "cd_marital_status", "wr_returning_cdemo_sk"]) + remote exchange (REPARTITION, HASH, ["cd_education_status", "cd_marital_status", "wr_returning_cdemo_sk"]) + join (INNER, REPLICATED): join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["wr_refunded_addr_sk"]) - join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, ["ws_item_sk", "ws_order_number"]) - join (INNER, REPLICATED): - scan web_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim + remote exchange (REPARTITION, HASH, ["ws_item_sk", "ws_order_number"]) + join (INNER, REPLICATED): + scan web_sales local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["wr_item_sk", "wr_order_number"]) - join (INNER, REPLICATED): - scan web_returns - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan customer_demographics + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, ["ca_address_sk"]) - scan customer_address + remote exchange (REPARTITION, HASH, ["wr_item_sk", "wr_order_number"]) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, ["wr_refunded_addr_sk"]) + scan web_returns + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["ca_address_sk"]) + scan customer_address + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan customer_demographics + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, ["cd_demo_sk_0", "cd_education_status_3", "cd_marital_status_2"]) + scan customer_demographics local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan web_page