Skip to content

Commit

Permalink
Consider correlation for output estimation of filter conjuncts
Browse files Browse the repository at this point in the history
Currently we assume that there is no correlation between
the terms of a filter conjunction. This can result in underestimation
as there is often some correlation between columns in real data sets.
In particular, predicates inferred on the build side relation through
a join with a partitioned table are often correlated with user provided
predicates on the build side.
Estimation for filter conjunctions now applies an exponential decay
to the selectivity of each successive term to reduce the chance of
underestimation.
optimizer.filter-conjunction-independence-factor is added to allow
tuning the strength of the independence assumption.
  • Loading branch information
raunaqmorarka authored and sopel39 committed Mar 11, 2022
1 parent 291b65d commit 67d2df9
Show file tree
Hide file tree
Showing 14 changed files with 472 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static io.trino.spi.session.PropertyMetadata.integerProperty;
import static io.trino.spi.session.PropertyMetadata.longProperty;
import static io.trino.spi.session.PropertyMetadata.stringProperty;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
import static java.lang.Math.min;
Expand Down Expand Up @@ -122,6 +123,7 @@ public final class SystemSessionProperties
public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures";
public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task";
public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled";
public static final String FILTER_CONJUNCTION_INDEPENDENCE_FACTOR = "filter_conjunction_independence_factor";
public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort";
public static final String ALLOW_PUSHDOWN_INTO_CONNECTORS = "allow_pushdown_into_connectors";
public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown";
Expand Down Expand Up @@ -557,6 +559,15 @@ public SystemSessionProperties(
"use a default filter factor for unknown filters in a filter node",
optimizerConfig.isDefaultFilterFactorEnabled(),
false),
new PropertyMetadata<>(
FILTER_CONJUNCTION_INDEPENDENCE_FACTOR,
"Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters",
DOUBLE,
Double.class,
optimizerConfig.getFilterConjunctionIndependenceFactor(),
false,
value -> validateDoubleRange(value, FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, 0.0, 1.0),
value -> value),
booleanProperty(
SKIP_REDUNDANT_SORT,
"Skip redundant sort operations",
Expand Down Expand Up @@ -1134,6 +1145,17 @@ private static Integer validateIntegerValue(Object value, String property, int l
return intValue;
}

/**
 * Checks that a session property value lies within an inclusive range.
 *
 * @param value the property value; it is cast to {@code double}
 * @param property the property name, used only in the error message
 * @param lowerBoundIncluded smallest accepted value
 * @param upperBoundIncluded largest accepted value
 * @return the validated value
 * @throws TrinoException with {@code INVALID_SESSION_PROPERTY} when the value is NaN or outside the range
 */
private static double validateDoubleRange(Object value, String property, double lowerBoundIncluded, double upperBoundIncluded)
{
    double doubleValue = (double) value;
    // Expressed as a negated "in range" test: every ordered comparison with NaN is false,
    // so "value < lower || value > upper" would silently accept NaN.
    if (!(doubleValue >= lowerBoundIncluded && doubleValue <= upperBoundIncluded)) {
        throw new TrinoException(
                INVALID_SESSION_PROPERTY,
                format("%s must be in the range [%.2f, %.2f]: %.2f", property, lowerBoundIncluded, upperBoundIncluded, doubleValue));
    }
    return doubleValue;
}

public static boolean isStatisticsCpuTimerEnabled(Session session)
{
return session.getSystemProperty(STATISTICS_CPU_TIMER_ENABLED, Boolean.class);
Expand Down Expand Up @@ -1164,6 +1186,11 @@ public static boolean isDefaultFilterFactorEnabled(Session session)
return session.getSystemProperty(DEFAULT_FILTER_FACTOR_ENABLED, Boolean.class);
}

/**
 * Reads the {@code filter_conjunction_independence_factor} system property from the given session.
 */
public static double getFilterConjunctionIndependenceFactor(Session session)
{
    Double independenceFactor = session.getSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, Double.class);
    return independenceFactor;
}

public static boolean isSkipRedundantSort(Session session)
{
return session.getSystemProperty(SKIP_REDUNDANT_SORT, Boolean.class);
Expand Down
136 changes: 111 additions & 25 deletions core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
package io.trino.cost;

import com.google.common.base.VerifyException;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AllowAllAccessControl;
Expand Down Expand Up @@ -44,30 +46,37 @@
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.util.DisjointSet;

import javax.annotation.Nullable;
import javax.inject.Inject;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison;
import static io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison;
import static io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues;
import static io.trino.cost.PlanNodeStatsEstimateMath.capStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount;
import static io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats;
import static io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats;
import static io.trino.spi.statistics.StatsUtil.toStatsRepresentation;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.DynamicFilters.isDynamicFilter;
import static io.trino.sql.ExpressionUtils.and;
import static io.trino.sql.ExpressionUtils.getExpressionTypes;
import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
Expand Down Expand Up @@ -137,7 +146,14 @@ private class FilterExpressionStatsCalculatingVisitor
@Override
public PlanNodeStatsEstimate process(Node node, @Nullable Void context)
{
return normalizer.normalize(super.process(node, context), types);
PlanNodeStatsEstimate output;
if (input.getOutputRowCount() == 0 || input.isOutputRowCountUnknown()) {
output = input;
}
else {
output = super.process(node, context);
}
return normalizer.normalize(output, types);
}

@Override
Expand Down Expand Up @@ -169,35 +185,56 @@ protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, V

private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms)
{
// first try to estimate in the fair way
PlanNodeStatsEstimate estimate = process(terms.get(0));
if (!estimate.isOutputRowCountUnknown()) {
for (int i = 1; i < terms.size(); i++) {
estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i));
double filterConjunctionIndependenceFactor = getFilterConjunctionIndependenceFactor(session);
List<PlanNodeStatsEstimate> estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor);
double outputRowCount = estimateCorrelatedConjunctionRowCount(
input,
estimates,
filterConjunctionIndependenceFactor);
if (isNaN(outputRowCount)) {
return PlanNodeStatsEstimate.unknown();
}
return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedStats(estimates)), types);
}

if (estimate.isOutputRowCountUnknown()) {
break;
/**
 * There can be multiple predicate expressions for the same symbol, e.g. x > 0 AND x <= 1, x BETWEEN 1 AND 10.
 * We attempt to detect such cases in extractCorrelatedGroups and calculate a combined estimate for each
 * such group of expressions. This is done so that the independence-factor scaling is not applied when
 * combining estimates from a conjunction of multiple predicates on the same symbol, which would
 * underestimate the output.
 */
private List<PlanNodeStatsEstimate> estimateCorrelatedExpressions(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
ImmutableList.Builder<PlanNodeStatsEstimate> estimatesBuilder = ImmutableList.builder();
boolean hasUnestimatedTerm = false;
for (List<Expression> correlatedExpressions : extractCorrelatedGroups(terms, filterConjunctionIndependenceFactor)) {
PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown();
for (Expression expression : correlatedExpressions) {
PlanNodeStatsEstimate estimate;
// combinedEstimate is unknown until the 1st known estimated term
if (combinedEstimate.isOutputRowCountUnknown()) {
estimate = process(expression);
}
else {
estimate = new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session, types)
.process(expression);
}
}

if (!estimate.isOutputRowCountUnknown()) {
return estimate;
if (estimate.isOutputRowCountUnknown()) {
hasUnestimatedTerm = true;
}
else {
// update combinedEstimate only when the term estimate is known so that all the known estimates
// can be applied progressively through FilterExpressionStatsCalculatingVisitor calls.
combinedEstimate = estimate;
}
}
estimatesBuilder.add(combinedEstimate);
}

// If some of the filters cannot be estimated, take the smallest estimate.
// Apply 0.9 filter factor as "unknown filter" factor.
Optional<PlanNodeStatsEstimate> smallest = terms.stream()
.map(this::process)
.filter(termEstimate -> !termEstimate.isOutputRowCountUnknown())
.sorted(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
.findFirst();

if (smallest.isEmpty()) {
return PlanNodeStatsEstimate.unknown();
if (hasUnestimatedTerm) {
estimatesBuilder.add(PlanNodeStatsEstimate.unknown());
}

return smallest.get().mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
return estimatesBuilder.build();
}

private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms)
Expand Down Expand Up @@ -442,4 +479,53 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal)
return toStatsRepresentation(type, literalValue);
}
}

/**
 * Splits the conjuncts into groups such that two terms land in the same group when they are
 * linked, directly or transitively, by a shared symbol. Terms that reference no symbols are
 * collected into one extra group.
 *
 * @param terms the conjuncts to partition
 * @param filterConjunctionIndependenceFactor when exactly 1, grouping is skipped and all terms are returned as one group
 * @return the groups of terms; each term of {@code terms} is placed in exactly one group
 */
private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor)
{
    if (filterConjunctionIndependenceFactor == 1) {
        // A single group lets the caller estimate the terms as if none of them were correlated
        return ImmutableList.of(terms);
    }

    ListMultimap<Expression, Symbol> symbolsByTerm = ArrayListMultimap.create();
    for (Expression term : terms) {
        symbolsByTerm.putAll(term, extractUnique(term));
    }

    // Union the symbols of every term, so that symbols from different sets never occur together in one term
    DisjointSet<Symbol> unionFind = new DisjointSet<>();
    for (Expression term : terms) {
        List<Symbol> symbols = symbolsByTerm.get(term);
        if (!symbols.isEmpty()) {
            Symbol representative = symbols.get(0);
            // find() registers the symbol even when it is the only one in the term
            unionFind.find(representative);
            for (Symbol other : symbols.subList(1, symbols.size())) {
                unionFind.findAndUnion(representative, other);
            }
        }
    }

    // Assign every term to the partition that holds its symbols
    List<Set<Symbol>> partitions = ImmutableList.copyOf(unionFind.getEquivalentClasses());
    checkState(partitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");
    int symbolFreePartitionId = partitions.size(); // id reserved for terms without symbols
    ListMultimap<Integer, Expression> termsByPartition = ArrayListMultimap.create();
    for (Expression term : terms) {
        List<Symbol> symbols = symbolsByTerm.get(term);
        // All symbols of a term share one partition, so looking up the first symbol is enough
        int partitionId = symbols.isEmpty()
                ? symbolFreePartitionId
                : IntStream.range(0, partitions.size())
                        .filter(index -> partitions.get(index).contains(symbols.get(0)))
                        .findFirst()
                        .orElseThrow();
        termsByPartition.put(partitionId, term);
    }

    ImmutableList.Builder<List<Expression>> groups = ImmutableList.builder();
    for (Integer partitionId : termsByPartition.keySet()) {
        groups.add(termsByPartition.get(partitionId));
    }
    return groups.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@
*/
package io.trino.cost;

import io.trino.sql.planner.Symbol;
import io.trino.util.MoreMath;

import java.util.List;
import java.util.Map;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.lang.Double.max;
import static java.lang.Double.min;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Stream.concat;

public final class PlanNodeStatsEstimateMath
Expand Down Expand Up @@ -135,6 +144,83 @@ public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNo
return result.build();
}

/**
 * Merges the per-symbol statistics of several estimates into a single map.
 * <p>
 * For every symbol with known statistics in at least one estimate, the statistic ranges of all
 * estimates are intersected, the nulls fractions are combined with {@code MoreMath.minExcludeNaN}
 * and the average row sizes with {@code MoreMath.averageExcludingNaNs}, folding left to right
 * over {@code estimates}. A single estimate is returned as is.
 *
 * @param estimates the estimates to combine; must not be empty
 * @return combined symbol statistics
 */
public static Map<Symbol, SymbolStatsEstimate> intersectCorrelatedStats(List<PlanNodeStatsEstimate> estimates)
{
    checkArgument(!estimates.isEmpty(), "estimates is empty");
    if (estimates.size() == 1) {
        return estimates.get(0).getSymbolStatistics();
    }

    List<Symbol> knownSymbols = estimates.stream()
            .flatMap(estimate -> estimate.getSymbolsWithKnownStatistics().stream())
            .distinct()
            .collect(toImmutableList());

    PlanNodeStatsEstimate.Builder combined = PlanNodeStatsEstimate.builder();
    for (Symbol symbol : knownSymbols) {
        SymbolStatsEstimate first = estimates.get(0).getSymbolStatistics(symbol);
        StatisticRange range = StatisticRange.from(first);
        double nullsFraction = first.getNullsFraction();
        double averageRowSize = first.getAverageRowSize();

        for (PlanNodeStatsEstimate estimate : estimates.subList(1, estimates.size())) {
            SymbolStatsEstimate next = estimate.getSymbolStatistics(symbol);
            range = range.intersect(StatisticRange.from(next));
            // The result should look as if the filters were applied one after another.
            // Taking the minimum suits filters like (a > 10 AND b < 10) but not
            // (a > 10 AND b IS NULL); the former case is the more common one.
            nullsFraction = MoreMath.minExcludeNaN(nullsFraction, next.getNullsFraction());
            averageRowSize = MoreMath.averageExcludingNaNs(averageRowSize, next.getAverageRowSize());
        }

        combined.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                .setStatisticsRange(range)
                .setNullsFraction(nullsFraction)
                .setAverageRowSize(averageRowSize)
                .build());
    }
    return combined.build().getSymbolStatistics();
}

/**
 * Estimates the output row count of a conjunction from the estimates of its terms, damping the
 * contribution of each successive term.
 * <p>
 * Known term estimates are ordered by ascending row count. The first contributes its full
 * selectivity; the i-th following term contributes its selectivity raised to
 * {@code independenceFactor ^ i}. For a factor of 0.75 and terms t1, t2, t3:
 * {@code sel(t1) * sel(t2)^0.75 * sel(t3)^(0.75 * 0.75)}. A factor of 1 multiplies the
 * selectivities unscaled (no correlation assumed); a factor of 0 lets only the most selective
 * term drive the result (full correlation assumed).
 *
 * @param input statistics of the relation being filtered
 * @param estimates per-term estimates, possibly including unknown ones; must not be empty
 * @param independenceFactor decay applied to the exponent of each successive term
 * @return the input row count when it is unknown or zero, NaN when no term estimate is known,
 *         otherwise the estimated row count, multiplied by {@code UNKNOWN_FILTER_COEFFICIENT}
 *         if any term estimate is unknown
 */
public static double estimateCorrelatedConjunctionRowCount(
        PlanNodeStatsEstimate input,
        List<PlanNodeStatsEstimate> estimates,
        double independenceFactor)
{
    checkArgument(!estimates.isEmpty(), "estimates is empty");
    double inputRowCount = input.getOutputRowCount();
    if (input.isOutputRowCountUnknown() || inputRowCount == 0) {
        return inputRowCount;
    }

    List<PlanNodeStatsEstimate> knownAscending = estimates.stream()
            .filter(estimate -> !estimate.isOutputRowCountUnknown())
            .sorted(comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
            .collect(toImmutableList());
    if (knownAscending.isEmpty()) {
        return NaN;
    }

    // The most selective term is applied at full strength; every later term gets a smaller exponent
    double selectivity = knownAscending.get(0).getOutputRowCount() / inputRowCount;
    double exponent = 1.0;
    for (PlanNodeStatsEstimate term : knownAscending.subList(1, knownAscending.size())) {
        exponent *= independenceFactor;
        selectivity *= Math.pow(term.getOutputRowCount() / inputRowCount, exponent);
    }
    double outputRowCount = inputRowCount * selectivity;

    // TODO use UNKNOWN_FILTER_COEFFICIENT only when default-filter-factor is enabled
    // Fewer known estimates than total estimates means at least one term could not be estimated
    if (knownAscending.size() < estimates.size()) {
        return outputRowCount * UNKNOWN_FILTER_COEFFICIENT;
    }
    return outputRowCount;
}

private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats)
{
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
Expand Down
Loading

0 comments on commit 67d2df9

Please sign in to comment.