diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java index 248d145b390d..d78a457efead 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java @@ -139,8 +139,12 @@ public static boolean predicateMatches( DateTimeZone timeZone) throws IOException { + if (block.getRowCount() == 0) { + return false; + } Map> columnStatistics = getStatistics(block, descriptorsByPath); - Optional> candidateColumns = parquetPredicate.getIndexLookupCandidates(block.getRowCount(), columnStatistics, dataSource.getId()); + Map columnValueCounts = getColumnValueCounts(block, descriptorsByPath); + Optional> candidateColumns = parquetPredicate.getIndexLookupCandidates(columnValueCounts, columnStatistics, dataSource.getId()); if (candidateColumns.isEmpty()) { return false; } @@ -153,7 +157,7 @@ public static boolean predicateMatches( TupleDomainParquetPredicate indexPredicate = new TupleDomainParquetPredicate(parquetTupleDomain, candidateColumns.get(), timeZone); // Page stats is finer grained but relatively more expensive, so we do the filtering after above block filtering. - if (columnIndexStore.isPresent() && !indexPredicate.matches(block.getRowCount(), columnIndexStore.get(), dataSource.getId())) { + if (columnIndexStore.isPresent() && !indexPredicate.matches(columnValueCounts, columnIndexStore.get(), dataSource.getId())) { return false; } @@ -181,6 +185,18 @@ private static Map> getStatistics(BlockMetaData return statistics.buildOrThrow(); } + private static Map getColumnValueCounts(BlockMetaData blockMetadata, Map, ColumnDescriptor> descriptorsByPath) + { + ImmutableMap.Builder columnValueCounts = ImmutableMap.builder(); + for (ColumnChunkMetaData columnMetaData : blockMetadata.getColumns()) { + ColumnDescriptor descriptor = descriptorsByPath.get(Arrays.asList(columnMetaData.getPath().toArray())); + if (descriptor != null) { + columnValueCounts.put(descriptor, columnMetaData.getValueCount()); + } + } + return columnValueCounts.buildOrThrow(); + } + private static boolean dictionaryPredicatesMatch( TupleDomainParquetPredicate parquetPredicate, BlockMetaData blockMetadata, diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java index 05f76fca5bcb..7b1824e42358 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/predicate/TupleDomainParquetPredicate.java @@ -95,7 +95,7 @@ public TupleDomainParquetPredicate(TupleDomain effectivePredic * and if it should, then return the columns are candidates for further inspection of more * granular statistics from column index and dictionary. * - * @param numberOfRows the number of rows in the segment; this can be used with + * @param valueCounts the number of values for a column in the segment; this can be used with * Statistics to determine if a column is only null * @param statistics column statistics * @param id Parquet file name @@ -105,12 +105,12 @@ public TupleDomainParquetPredicate(TupleDomain effectivePredic * to potentially eliminate the file section. An optional with empty list is returned if there is * going to be no benefit in looking at column index or dictionary for any column. */ - public Optional> getIndexLookupCandidates(long numberOfRows, Map> statistics, ParquetDataSourceId id) + public Optional> getIndexLookupCandidates( + Map valueCounts, + Map> statistics, + ParquetDataSourceId id) throws ParquetCorruptionException { - if (numberOfRows == 0) { - return Optional.empty(); - } if (effectivePredicate.isNone()) { return Optional.empty(); } @@ -131,10 +131,14 @@ public Optional> getIndexLookupCandidates(long numberOfRo continue; } + Long columnValueCount = valueCounts.get(column); + if (columnValueCount == null) { + throw new IllegalArgumentException(format("Missing columnValueCount for column %s in %s", column, id)); + } Domain domain = getDomain( column, effectivePredicateDomain.getType(), - numberOfRows, + columnValueCount, columnStatistics, id, timeZone); @@ -174,20 +178,15 @@ public boolean matches(DictionaryDescriptor dictionary) /** * Should the Parquet Reader process a file section with the specified statistics. * - * @param numberOfRows the number of rows in the segment; this can be used with + * @param valueCounts the number of values for a column in the segment; this can be used with * Statistics to determine if a column is only null * @param columnIndexStore column index (statistics) store * @param id Parquet file name */ - public boolean matches(long numberOfRows, ColumnIndexStore columnIndexStore, ParquetDataSourceId id) + public boolean matches(Map valueCounts, ColumnIndexStore columnIndexStore, ParquetDataSourceId id) throws ParquetCorruptionException { requireNonNull(columnIndexStore, "columnIndexStore is null"); - - if (numberOfRows == 0) { - return false; - } - if (effectivePredicate.isNone()) { return false; } @@ -206,7 +205,11 @@ public boolean matches(long numberOfRows, ColumnIndexStore columnIndexStore, Par continue; } - Domain domain = getDomain(effectivePredicateDomain.getType(), numberOfRows, columnIndex, id, column, timeZone); + Long columnValueCount = valueCounts.get(column); + if (columnValueCount == null) { + throw new IllegalArgumentException(format("Missing columnValueCount for column %s in %s", column, id)); + } + Domain domain = getDomain(effectivePredicateDomain.getType(), columnValueCount, columnIndex, id, column, timeZone); if (!effectivePredicateDomain.overlaps(domain)) { return false; } @@ -235,7 +238,7 @@ private boolean effectivePredicateMatches(Domain effectivePredicateDomain, Dicti public static Domain getDomain( ColumnDescriptor column, Type type, - long rowCount, + long columnValuesCount, Statistics statistics, ParquetDataSourceId id, DateTimeZone timeZone) @@ -245,7 +248,7 @@ public static Domain getDomain( return Domain.all(type); } - if (statistics.isNumNullsSet() && statistics.getNumNulls() == rowCount) { + if (statistics.isNumNullsSet() && statistics.getNumNulls() == columnValuesCount) { return Domain.onlyNull(type); } @@ -437,7 +440,7 @@ private static Domain getDomain( @VisibleForTesting public static Domain getDomain( Type type, - long rowCount, + long columnValuesCount, ColumnIndex columnIndex, ParquetDataSourceId id, ColumnDescriptor descriptor, @@ -466,7 +469,7 @@ public static Domain getDomain( .sum(); boolean hasNullValue = totalNullCount > 0; - if (hasNullValue && totalNullCount == rowCount) { + if (hasNullValue && totalNullCount == columnValuesCount) { return Domain.onlyNull(type); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/TestTupleDomainParquetPredicate.java b/lib/trino-parquet/src/test/java/io/trino/parquet/TestTupleDomainParquetPredicate.java index 01a8f2bae0b9..627676c06304 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/TestTupleDomainParquetPredicate.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/TestTupleDomainParquetPredicate.java @@ -154,7 +154,8 @@ public void testBigint() assertEquals(getDomain(columnDescriptor, BIGINT, 10, longColumnStats(0L, 100L), ID, UTC), create(ValueSet.ofRanges(range(BIGINT, 0L, true, 100L, true)), false)); - assertEquals(getDomain(columnDescriptor, BIGINT, 20, longOnlyNullsStats(10), ID, UTC), create(ValueSet.all(BIGINT), true)); + assertEquals(getDomain(columnDescriptor, BIGINT, 20, longOnlyNullsStats(10), ID, UTC), Domain.all(BIGINT)); + assertEquals(getDomain(columnDescriptor, BIGINT, 20, longOnlyNullsStats(20), ID, UTC), Domain.onlyNull(BIGINT)); // fail on corrupted statistics assertThatExceptionOfType(ParquetCorruptionException.class) .isThrownBy(() -> getDomain(columnDescriptor, BIGINT, 10, longColumnStats(100L, 10L), ID, UTC)) @@ -555,7 +556,7 @@ public void testVarcharMatchesWithStatistics() .withMax(value.getBytes(UTF_8)) .withNumNulls(1L) .build(); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, stats), ID)) + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, stats), ID)) .isEqualTo(Optional.of(ImmutableList.of(column))); } @@ -569,10 +570,10 @@ public void testIntegerMatchesWithStatistics(Type typeForParquetInt32) Domain.create(ValueSet.of(typeForParquetInt32, 42L, 43L, 44L, 112L), false))); TupleDomainParquetPredicate parquetPredicate = new TupleDomainParquetPredicate(effectivePredicate, singletonList(column), UTC); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, intColumnStats(32, 42)), ID)) + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, intColumnStats(32, 42)), ID)) .isEqualTo(Optional.of(ImmutableList.of(column))); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, intColumnStats(30, 40)), ID)).isEmpty(); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, intColumnStats(1024, 0x10000 + 42)), ID).isPresent()) + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, intColumnStats(30, 40)), ID)).isEmpty(); + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, intColumnStats(1024, 0x10000 + 42)), ID).isPresent()) .isEqualTo(typeForParquetInt32 != INTEGER); // stats invalid for smallint/tinyint } @@ -596,10 +597,10 @@ public void testBigintMatchesWithStatistics() Domain.create(ValueSet.of(BIGINT, 42L, 43L, 44L, 404L), false))); TupleDomainParquetPredicate parquetPredicate = new TupleDomainParquetPredicate(effectivePredicate, singletonList(column), UTC); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, longColumnStats(32, 42)), ID)) + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, longColumnStats(32, 42)), ID)) .isEqualTo(Optional.of(ImmutableList.of(column))); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, longColumnStats(30, 40)), ID)).isEmpty(); - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(column, longColumnStats(1024, 0x10000 + 42)), ID)).isEmpty(); + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, longColumnStats(30, 40)), ID)).isEmpty(); + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(column, 2L), ImmutableMap.of(column, longColumnStats(1024, 0x10000 + 42)), ID)).isEmpty(); } @Test @@ -669,23 +670,23 @@ public void testIndexLookupCandidates() TupleDomainParquetPredicate parquetPredicate = new TupleDomainParquetPredicate(effectivePredicate, singletonList(columnA), UTC); assertThat(parquetPredicate.getIndexLookupCandidates( - 2, + ImmutableMap.of(columnA, 2L, columnB, 2L), ImmutableMap.of(columnA, longColumnStats(32, 42), columnB, longColumnStats(42, 500)), ID)) .isEqualTo(Optional.of(ImmutableList.of(columnA))); parquetPredicate = new TupleDomainParquetPredicate(effectivePredicate, ImmutableList.of(columnA, columnB), UTC); // column stats missing on columnB - assertThat(parquetPredicate.getIndexLookupCandidates(2, ImmutableMap.of(columnA, longColumnStats(32, 42)), ID)) + assertThat(parquetPredicate.getIndexLookupCandidates(ImmutableMap.of(columnA, 2L), ImmutableMap.of(columnA, longColumnStats(32, 42)), ID)) .isEqualTo(Optional.of(ImmutableList.of(columnA, columnB))); // All possible values for columnB are covered by effectivePredicate assertThat(parquetPredicate.getIndexLookupCandidates( - 2, + ImmutableMap.of(columnA, 2L, columnB, 2L), ImmutableMap.of(columnA, longColumnStats(32, 42), columnB, longColumnStats(50, 400)), ID)) .isEqualTo(Optional.of(ImmutableList.of(columnA))); assertThat(parquetPredicate.getIndexLookupCandidates( - 2, + ImmutableMap.of(columnA, 2L, columnB, 2L), ImmutableMap.of(columnA, longColumnStats(32, 42), columnB, longColumnStats(42, 500)), ID)) .isEqualTo(Optional.of(ImmutableList.of(columnA, columnB))); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 1223edc6bdf7..845f05094d40 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -5187,6 +5187,32 @@ private void testParquetDictionaryPredicatePushdown(Session session) assertNoDataRead("SELECT * FROM " + tableName + " WHERE n = 3"); } + @Test + public void testParquetOnlyNullsRowGroupPruning() + { + String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); + + tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); + // Nested column `a` has nulls count of 4096 and contains only nulls + // Nested column `b` also has nulls count of 4096, but it contains non nulls as well + assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); + // TODO replace with assertNoDataRead after nested column predicate pushdown + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + } + private void assertNoDataRead(@Language("SQL") String sql) { assertQueryStats(