Skip to content

Commit

Permalink
Populate a nulls fraction when NDV is known in TableScanStatsRule
Browse files Browse the repository at this point in the history
If a connector provides NDV but is missing nulls fraction statistic for a column
(e.g. Delta Lake after "delta.dataSkippingNumIndexedCols" columns and MySql), populate a
10% guess value so that the CBO can still produce some estimates rather than
failing to make any estimates due to lack of nulls fraction.
  • Loading branch information
raunaqmorarka committed Nov 23, 2022
1 parent b362da8 commit c2666d7
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.matching.Pattern;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.FixedWidthType;
import io.trino.spi.type.Type;
Expand All @@ -37,6 +38,7 @@
public class TableScanStatsRule
extends SimpleStatsRule<TableScanNode>
{
private static final double UNKNOWN_NULLS_FRACTION = 0.1;
private static final Pattern<TableScanNode> PATTERN = tableScan();

public TableScanStatsRule(StatsNormalizer normalizer)
Expand Down Expand Up @@ -82,7 +84,7 @@ private static SymbolStatsEstimate toSymbolStatistics(TableStatistics tableStati
requireNonNull(columnStatistics, "columnStatistics is null");
requireNonNull(type, "type is null");

double nullsFraction = columnStatistics.getNullsFraction().getValue();
double nullsFraction = getNullsFraction(columnStatistics, tableStatistics.getRowCount());
double nonNullRowsCount = tableStatistics.getRowCount().getValue() * (1.0 - nullsFraction);
double averageRowSize;
if (nonNullRowsCount == 0) {
Expand All @@ -105,4 +107,25 @@ else if (type instanceof FixedWidthType) {
});
return result.build();
}

/**
 * Resolves the nulls fraction to use for a column.
 * <p>
 * The connector-provided nulls fraction is returned as-is (even when unknown) unless it is
 * unknown while both the distinct values count (NDV) and the row count are known. In that
 * case a guess is derived instead, so that the CBO can still produce some estimates rather
 * than failing to make any estimates due to the lack of a nulls fraction
 * (e.g. Delta Lake after "delta.dataSkippingNumIndexedCols" columns and MySql).
 *
 * @param columnStatistics statistics reported for the column
 * @param rowCount row count reported for the table
 * @return the reported nulls fraction, or a guess capped at {@code UNKNOWN_NULLS_FRACTION}
 */
private static double getNullsFraction(ColumnStatistics columnStatistics, Estimate rowCount)
{
    Estimate reportedNullsFraction = columnStatistics.getNullsFraction();
    Estimate distinctValuesCount = columnStatistics.getDistinctValuesCount();

    boolean guessNeeded = reportedNullsFraction.isUnknown()
            && !distinctValuesCount.isUnknown()
            && !rowCount.isUnknown();
    if (!guessNeeded) {
        return reportedNullsFraction.getValue();
    }

    double rows = rowCount.getValue();
    double distinctValues = distinctValuesCount.getValue();

    // When NDV is greater than or equal to row count, there are no nulls
    if (distinctValues >= rows) {
        return 0;
    }

    // Rows not accounted for by distinct values bound the number of nulls from above,
    // so the default guess is capped by that share of the table.
    double maxPossibleNullsFraction = (rows - distinctValues) / rows;
    return Math.min(UNKNOWN_NULLS_FRACTION, maxPossibleNullsFraction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class StatsCalculatorAssertion

private final Map<PlanNode, PlanNodeStatsEstimate> sourcesStats;

private Optional<TableStatsProvider> tableStatsProvider = Optional.empty();

public StatsCalculatorAssertion(Metadata metadata, StatsCalculator statsCalculator, Session session, PlanNode planNode, TypeProvider types)
{
this.metadata = requireNonNull(metadata, "metadata cannot be null");
Expand Down Expand Up @@ -83,19 +85,38 @@ public StatsCalculatorAssertion withSourceStats(Map<PlanNode, PlanNodeStatsEstim
return this;
}

/**
 * Installs the table statistics provider that the {@code check} methods pass to the stats
 * calculation instead of a newly created {@code CachingTableStatsProvider}.
 *
 * @param tableStatsProvider provider to use; must not be null
 * @return this assertion object, to allow call chaining
 */
public StatsCalculatorAssertion withTableStatisticsProvider(TableStatsProvider tableStatsProvider)
{
    Optional<TableStatsProvider> providerOverride = Optional.of(tableStatsProvider);
    this.tableStatsProvider = providerOverride;
    return this;
}

/**
 * Calculates statistics for {@code planNode} with {@code statsCalculator} inside a testing
 * transaction and passes the result, wrapped by {@link PlanNodeStatsAssertion#assertThat},
 * to the given consumer.
 *
 * @param statisticsAssertionConsumer assertions to run against the calculated estimate
 * @return this assertion object, to allow call chaining
 */
public StatsCalculatorAssertion check(Consumer<PlanNodeStatsAssertion> statisticsAssertionConsumer)
{
PlanNodeStatsEstimate statsEstimate = transaction(new TestingTransactionManager(), new AllowAllAccessControl())
.execute(session, transactionSession -> {
// NOTE(review): the single-line return below and the multi-line return after it look like the
// before/after sides of the same statement in this diff view; confirm only the multi-line form
// remains in the committed file.
return statsCalculator.calculateStats(planNode, this::getSourceStats, noLookup(), transactionSession, types, new CachingTableStatsProvider(metadata, session));
// Table statistics come from the provider set via withTableStatisticsProvider when present;
// otherwise a new CachingTableStatsProvider built from metadata and session is used.
return statsCalculator.calculateStats(
planNode,
this::getSourceStats,
noLookup(),
transactionSession,
types,
tableStatsProvider.orElse(new CachingTableStatsProvider(metadata, session)));
});
statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate));
return this;
}

public StatsCalculatorAssertion check(Rule<?> rule, Consumer<PlanNodeStatsAssertion> statisticsAssertionConsumer)
{
Optional<PlanNodeStatsEstimate> statsEstimate = calculatedStats(rule, planNode, this::getSourceStats, noLookup(), session, types, new CachingTableStatsProvider(metadata, session));
Optional<PlanNodeStatsEstimate> statsEstimate = calculatedStats(
rule,
planNode,
this::getSourceStats,
noLookup(),
session,
types,
tableStatsProvider.orElse(new CachingTableStatsProvider(metadata, session)));
checkState(statsEstimate.isPresent(), "Expected stats estimates to be present");
statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate.get()));
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.TestingColumnHandle;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.DoubleRange;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import io.trino.sql.planner.Symbol;
import org.testng.annotations.Test;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;

/**
 * Verifies the symbol statistics derived for table scans, in particular how the nulls
 * fraction is filled in when a connector reports a distinct values count (NDV) without one.
 */
public class TestTableScanStatsRule
        extends BaseStatsCalculatorTest
{
    @Test
    public void testStatsForTableScan()
    {
        ColumnHandle handleA = new TestingColumnHandle("a");
        ColumnHandle handleB = new TestingColumnHandle("b");
        ColumnHandle handleC = new TestingColumnHandle("c");
        ColumnHandle handleD = new TestingColumnHandle("d");
        ColumnHandle handleE = new TestingColumnHandle("e");
        ColumnHandle handleUnknown = new TestingColumnHandle("unknown");

        // a: NDV only, well below the row count
        ColumnStatistics ndvOnly = ColumnStatistics.builder()
                .setDistinctValuesCount(Estimate.of(20))
                .build();
        // b: both nulls fraction and NDV reported by the connector
        ColumnStatistics nullsFractionAndNdv = ColumnStatistics.builder()
                .setNullsFraction(Estimate.of(0.3))
                .setDistinctValuesCount(Estimate.of(23.1))
                .build();
        // c: only a value range, neither NDV nor nulls fraction
        ColumnStatistics rangeOnly = ColumnStatistics.builder()
                .setRange(new DoubleRange(15, 20))
                .build();
        // d: NDV equal to the row count
        ColumnStatistics ndvEqualToRowCount = ColumnStatistics.builder()
                .setDistinctValuesCount(Estimate.of(33))
                .build();
        // e: NDV just below the row count
        ColumnStatistics ndvCloseToRowCount = ColumnStatistics.builder()
                .setDistinctValuesCount(Estimate.of(31))
                .build();

        TableStatistics statistics = TableStatistics.builder()
                .setRowCount(Estimate.of(33))
                .setColumnStatistics(handleA, ndvOnly)
                .setColumnStatistics(handleB, nullsFractionAndNdv)
                .setColumnStatistics(handleC, rangeOnly)
                .setColumnStatistics(handleD, ndvEqualToRowCount)
                .setColumnStatistics(handleE, ndvCloseToRowCount)
                .setColumnStatistics(handleUnknown, ColumnStatistics.empty())
                .build();

        tester()
                .assertStatsFor(planBuilder -> {
                    Symbol symbolA = planBuilder.symbol("a", BIGINT);
                    Symbol symbolB = planBuilder.symbol("b", DOUBLE);
                    Symbol symbolC = planBuilder.symbol("c", DOUBLE);
                    Symbol symbolD = planBuilder.symbol("d", DOUBLE);
                    Symbol symbolE = planBuilder.symbol("e", INTEGER);
                    Symbol symbolUnknown = planBuilder.symbol("unknown", INTEGER);
                    return planBuilder.tableScan(
                            ImmutableList.of(symbolA, symbolB, symbolC, symbolD, symbolE, symbolUnknown),
                            ImmutableMap.of(
                                    symbolA, handleA,
                                    symbolB, handleB,
                                    symbolC, handleC,
                                    symbolD, handleD,
                                    symbolE, handleE,
                                    symbolUnknown, handleUnknown));
                })
                .withTableStatisticsProvider(ignored -> statistics)
                .check(stats -> stats
                        .outputRowsCount(33)
                        // UNKNOWN_NULLS_FRACTION populated to allow CBO to use NDV for estimation
                        .symbolStats("a", column -> column
                                .distinctValuesCount(20)
                                .nullsFraction(0.1))
                        // a reported nulls fraction is kept unchanged
                        .symbolStats("b", column -> column
                                .distinctValuesCount(23.1)
                                .nullsFraction(0.3))
                        // without NDV the nulls fraction stays unknown
                        .symbolStats("c", column -> column
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown()
                                .lowValue(15)
                                .highValue(20))
                        // NDV equal to the row count leaves no room for nulls
                        .symbolStats("d", column -> column
                                .distinctValuesCount(33)
                                .nullsFraction(0))
                        // guess is capped by (33 - 31) / 33
                        .symbolStats("e", column -> column
                                .distinctValuesCount(31)
                                .nullsFraction(0.06060606))
                        .symbolStats("unknown", column -> column
                                .unknownRange()
                                .distinctValuesCountUnknown()
                                .nullsFractionUnknown()
                                .dataSizeUnknown()));
    }

    @Test
    public void testZeroStatsForTableScan()
    {
        ColumnHandle zeroHandle = new TestingColumnHandle("zero");

        // Empty table: zero rows and zero distinct values, nulls fraction not reported
        ColumnStatistics zeroNdv = ColumnStatistics.builder()
                .setDistinctValuesCount(Estimate.zero())
                .build();
        TableStatistics emptyTable = TableStatistics.builder()
                .setRowCount(Estimate.zero())
                .setColumnStatistics(zeroHandle, zeroNdv)
                .build();

        tester()
                .assertStatsFor(planBuilder -> {
                    Symbol zero = planBuilder.symbol("zero", INTEGER);
                    return planBuilder.tableScan(
                            ImmutableList.of(zero),
                            ImmutableMap.of(zero, zeroHandle));
                })
                .withTableStatisticsProvider(ignored -> emptyTable)
                .check(stats -> stats
                        .outputRowsCount(0)
                        .symbolStats("zero", column -> column.isEqualTo(SymbolStatsEstimate.zero())));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3128,9 +3128,9 @@ public void testBasicAnalyze()

String statsWithNdv = format == AVRO
? ("VALUES " +
" ('regionkey', NULL, 5e0, NULL, NULL, NULL, NULL), " +
" ('name', NULL, 5e0, NULL, NULL, NULL, NULL), " +
" ('comment', NULL, 5e0, NULL, NULL, NULL, NULL), " +
" ('regionkey', NULL, 5e0, 0e0, NULL, NULL, NULL), " +
" ('name', NULL, 5e0, 0e0, NULL, NULL, NULL), " +
" ('comment', NULL, 5e0, 0e0, NULL, NULL, NULL), " +
" (NULL, NULL, NULL, NULL, 5e0, NULL, NULL)")
: ("VALUES " +
" ('regionkey', NULL, 5e0, 0e0, NULL, '0', '4'), " +
Expand Down Expand Up @@ -4019,20 +4019,20 @@ public void testAllAvailableTypes()
assertThat(query(extendedStatisticsEnabled, "SHOW STATS FOR test_all_types"))
.skippingTypesCheck()
.matches("VALUES " +
" ('a_boolean', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('an_integer', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_bigint', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_real', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_double', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_short_decimal', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_long_decimal', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_varchar', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_varbinary', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_date', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_time', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_timestamp', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_timestamptz', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_uuid', NULL, 1e0, NULL, NULL, NULL, NULL), " +
" ('a_boolean', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('an_integer', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_bigint', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_real', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_double', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_short_decimal', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_long_decimal', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_varchar', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_varbinary', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_date', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_time', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_timestamp', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_timestamptz', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_uuid', NULL, 1e0, 0.1e0, NULL, NULL, NULL), " +
" ('a_row', NULL, NULL, NULL, NULL, NULL, NULL), " +
" ('an_array', NULL, NULL, NULL, NULL, NULL, NULL), " +
" ('a_map', NULL, NULL, NULL, NULL, NULL, NULL), " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ public class MySqlClient

private static final JsonCodec<ColumnHistogram> HISTOGRAM_CODEC = jsonCodec(ColumnHistogram.class);

// We don't know null fraction, but having no null fraction will make CBO useless. Assume some arbitrary value.
private static final Estimate UNKNOWN_NULL_FRACTION_REPLACEMENT = Estimate.of(0.1);

private final Type jsonType;
private final boolean statisticsEnabled;
private final ConnectorExpressionRewriter<String> connectorExpressionRewriter;
Expand Down Expand Up @@ -798,13 +795,7 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH
rowCount = max(rowCount, columnIndexStatistics.getCardinality());
}

ColumnStatistics columnStatistics = columnStatisticsBuilder.build();
if (!columnStatistics.getDistinctValuesCount().isUnknown() && columnStatistics.getNullsFraction().isUnknown()) {
columnStatisticsBuilder.setNullsFraction(UNKNOWN_NULL_FRACTION_REPLACEMENT);
columnStatistics = columnStatisticsBuilder.build();
}

tableStatistics.setColumnStatistics(column, columnStatistics);
tableStatistics.setColumnStatistics(column, columnStatisticsBuilder.build());
}

tableStatistics.setRowCount(Estimate.of(rowCount));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ public void testNumericCornerCases()
// "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," +
// "('nans_only', null, 1.0, 0.5, null, null, null)," +
// "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," +
"('large_doubles', null, 1.9, 0.050000000000000044, null, null, null)," +
"('short_decimals_big_fraction', null, 1.9, 0.050000000000000044, null, null, null)," +
"('short_decimals_big_integral', null, 1.9, 0.050000000000000044, null, null, null)," +
"('long_decimals_big_fraction', null, 1.9, 0.050000000000000044, null, null, null)," +
"('long_decimals_middle', null, 1.9, 0.050000000000000044, null, null, null)," +
"('long_decimals_big_integral', null, 1.9, 0.050000000000000044, null, null, null)," +
"('large_doubles', null, 2.0, 0.0, null, null, null)," +
"('short_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
"('short_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
"('long_decimals_big_fraction', null, 2.0, 0.0, null, null, null)," +
"('long_decimals_middle', null, 2.0, 0.0, null, null, null)," +
"('long_decimals_big_integral', null, 2.0, 0.0, null, null, null)," +
"(null, null, null, null, 2, null, null)");
}
}
Expand Down Expand Up @@ -394,15 +394,24 @@ protected void assertColumnStats(MaterializedResult statsResult, Map<String, Int
.isEqualTo(0);
}

AbstractDoubleAssert<?> ndvAssertion = assertThat((Double) row.getField(2)).as("NDV for " + columnName);
Double distinctCount = (Double) row.getField(2);
Double nullsFraction = (Double) row.getField(3);
AbstractDoubleAssert<?> ndvAssertion = assertThat(distinctCount).as("NDV for " + columnName);
if (expectedNdv == null) {
ndvAssertion.isNull();
assertNull(row.getField(3), "null fraction for " + columnName);
assertNull(nullsFraction, "null fraction for " + columnName);
}
else {
ndvAssertion.isBetween(expectedNdv * 0.5, min(expectedNdv * 4.0, tableCardinality)); // [-50%, +300%] but no more than row count
assertThat((Double) row.getField(3)).as("Null fraction for " + columnName)
.isBetween(expectedNullFraction * 0.4, min(expectedNullFraction * 1.1, 1.0));
AbstractDoubleAssert<?> nullsAssertion = assertThat(nullsFraction).as("Null fraction for " + columnName);
if (distinctCount.compareTo(tableCardinality) >= 0) {
nullsAssertion.isEqualTo(0);
}
else {
double maxNullsFraction = (tableCardinality - distinctCount) / tableCardinality;
expectedNullFraction = Math.min(expectedNullFraction, maxNullsFraction);
nullsAssertion.isBetween(expectedNullFraction * 0.4, expectedNullFraction * 1.1);
}
}

assertNull(row.getField(4), "min");
Expand Down

0 comments on commit c2666d7

Please sign in to comment.