From d6f2da118641364cd23ef87e2262617dbbfee68f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 21 Jun 2022 14:13:23 +0200 Subject: [PATCH] Unwrap timestamptz cast to date in Iceberg --- .../plugin/iceberg/ConstraintExtractor.java | 210 ++++++++++++++ .../trino/plugin/iceberg/IcebergMetadata.java | 65 +++-- .../iceberg/BaseIcebergConnectorTest.java | 66 +++++ .../iceberg/TestConstraintExtractor.java | 258 ++++++++++++++++++ 4 files changed, 573 insertions(+), 26 deletions(-) create mode 100644 plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstraintExtractor.java create mode 100644 plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstraintExtractor.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstraintExtractor.java new file mode 100644 index 000000000000..2c9afb28b10a --- /dev/null +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstraintExtractor.java @@ -0,0 +1,210 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.Constraint; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Constant; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.DateType; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.Type; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.plugin.base.expression.ConnectorExpressions.and; +import static io.trino.plugin.base.expression.ConnectorExpressions.extractConjuncts; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_DAY; +import static java.util.Objects.requireNonNull; + +public final class ConstraintExtractor +{ + private ConstraintExtractor() {} + + public static ExtractionResult extractTupleDomain(Constraint constraint) + { + TupleDomain result = constraint.getSummary() + .transformKeys(IcebergColumnHandle.class::cast); + ImmutableList.Builder remainingExpressions = ImmutableList.builder(); + for (ConnectorExpression conjunct : extractConjuncts(constraint.getExpression())) { + Optional> converted = toTupleDomain(conjunct, constraint.getAssignments()); + if (converted.isEmpty()) { + remainingExpressions.add(conjunct); + } + else { + result = result.intersect(converted.get()); + if (result.isNone()) { + return new ExtractionResult(TupleDomain.none(), Constant.TRUE); + } + } + } + return new ExtractionResult(result, and(remainingExpressions.build())); + } + + private static Optional> toTupleDomain(ConnectorExpression expression, Map assignments) + { + if (expression instanceof Call) { + return toTupleDomain((Call) expression, assignments); + } + return Optional.empty(); + } + + private static Optional> toTupleDomain(Call call, Map assignments) + { + if (call.getArguments().size() == 2) { + ConnectorExpression firstArgument = call.getArguments().get(0); + ConnectorExpression secondArgument = call.getArguments().get(1); + + // Note: CanonicalizeExpressionRewriter ensures that constants are the second comparison argument. + + if (firstArgument instanceof Call && ((Call) firstArgument).getFunctionName().equals(CAST_FUNCTION_NAME) && + secondArgument instanceof Constant && + // if type do no match, this cannot be a comparison function + firstArgument.getType().equals(secondArgument.getType())) { + return unwrapCastInComparison( + call.getFunctionName(), + getOnlyElement(((Call) firstArgument).getArguments()), + (Constant) secondArgument, + assignments); + } + } + + return Optional.empty(); + } + + private static Optional> unwrapCastInComparison( + // upon invocation, we don't know if this really is a comparison + FunctionName functionName, + ConnectorExpression castSource, + Constant constant, + Map assignments) + { + if (!(castSource instanceof Variable)) { + // Engine unwraps casts in comparisons in UnwrapCastInComparison. Within a connector we can do more than + // engine only for source columns. We cannot draw many conclusions for intermediate expressions without + // knowing them well. + return Optional.empty(); + } + + if (constant.getValue() == null) { + // Comparisons with NULL should be simplified by the engine + return Optional.empty(); + } + + IcebergColumnHandle column = resolve((Variable) castSource, assignments); + if (column.getType() instanceof TimestampWithTimeZoneType) { + // Iceberg supports only timestamp(6) with time zone + checkArgument(((TimestampWithTimeZoneType) column.getType()).getPrecision() == 6, "Unexpected type: %s", column.getType()); + + if (constant.getType() == DateType.DATE) { + return unwrapTimestampTzToDateCast(column, functionName, (long) constant.getValue()) + .map(domain -> TupleDomain.withColumnDomains(ImmutableMap.of(column, domain))); + } + // TODO support timestamp constant + } + + return Optional.empty(); + } + + private static Optional unwrapTimestampTzToDateCast(IcebergColumnHandle column, FunctionName functionName, long date) + { + Type type = column.getType(); + checkArgument(type.equals(TIMESTAMP_TZ_MICROS), "Column of unexpected type %s: %s ", type, column); + + // Verify no overflow. Date values must be in integer range. + verify(date <= Integer.MAX_VALUE, "Date value out of range: %s", date); + + // In Iceberg, timestamp with time zone values are all in UTC + + LongTimestampWithTimeZone startOfDate = LongTimestampWithTimeZone.fromEpochMillisAndFraction(date * MILLISECONDS_PER_DAY, 0, UTC_KEY); + LongTimestampWithTimeZone startOfNextDate = LongTimestampWithTimeZone.fromEpochMillisAndFraction((date + 1) * MILLISECONDS_PER_DAY, 0, UTC_KEY); + + if (functionName.equals(EQUAL_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.range(type, startOfDate, true, startOfNextDate, false)), false)); + } + if (functionName.equals(NOT_EQUAL_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThan(type, startOfDate), Range.greaterThanOrEqual(type, startOfNextDate)), false)); + } + if (functionName.equals(LESS_THAN_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThan(type, startOfDate)), false)); + } + if (functionName.equals(LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThan(type, startOfNextDate)), false)); + } + if (functionName.equals(GREATER_THAN_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(type, startOfNextDate)), false)); + } + if (functionName.equals(GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(type, startOfDate)), false)); + } + if (functionName.equals(IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME)) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThan(type, startOfDate), Range.greaterThanOrEqual(type, startOfNextDate)), true)); + } + + return Optional.empty(); + } + + private static IcebergColumnHandle resolve(Variable variable, Map assignments) + { + ColumnHandle columnHandle = assignments.get(variable.getName()); + checkArgument(columnHandle != null, "No assignment for %s", variable); + return (IcebergColumnHandle) columnHandle; + } + + public static class ExtractionResult + { + private final TupleDomain tupleDomain; + private final ConnectorExpression remainingExpression; + + public ExtractionResult(TupleDomain tupleDomain, ConnectorExpression remainingExpression) + { + this.tupleDomain = requireNonNull(tupleDomain, "tupleDomain is null"); + this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null"); + } + + public TupleDomain getTupleDomain() + { + return tupleDomain; + } + + public ConnectorExpression getRemainingExpression() + { + return remainingExpression; + } + } +} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index 0397733f0621..5718a18face2 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -157,6 +157,7 @@ import static io.trino.plugin.hive.HiveApplyProjectionUtil.extractSupportedProjectedColumns; import static io.trino.plugin.hive.HiveApplyProjectionUtil.replaceWithNewVariables; import static io.trino.plugin.hive.util.HiveUtil.isStructuralType; +import static io.trino.plugin.iceberg.ConstraintExtractor.extractTupleDomain; import static io.trino.plugin.iceberg.ExpressionConverter.toIcebergExpression; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_UPDATE_ROW_ID_COLUMN_ID; import static io.trino.plugin.iceberg.IcebergColumnHandle.TRINO_UPDATE_ROW_ID_COLUMN_NAME; @@ -1719,43 +1720,54 @@ public void rollback() public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { IcebergTableHandle table = (IcebergTableHandle) handle; - TupleDomain predicate = constraint.getSummary(); + ConstraintExtractor.ExtractionResult extractionResult = extractTupleDomain(constraint); + TupleDomain predicate = extractionResult.getTupleDomain(); if (predicate.isAll()) { return Optional.empty(); } - Table icebergTable = catalog.loadTable(session, table.getSchemaTableName()); - - Map unsupported = new LinkedHashMap<>(); - Map newEnforced = new LinkedHashMap<>(); - Map newUnenforced = new LinkedHashMap<>(); - Map domains = predicate.getDomains().orElseThrow(() -> new IllegalArgumentException("constraint summary is NONE")); - domains.forEach((column, domain) -> { - IcebergColumnHandle columnHandle = (IcebergColumnHandle) column; - // Iceberg metadata columns can not be used to filter a table scan in Iceberg library - // TODO (https://github.com/trinodb/trino/issues/8759) structural types cannot be used to filter a table scan in Iceberg library. - if (isMetadataColumnId(columnHandle.getId()) || isStructuralType(columnHandle.getType()) || - // Iceberg orders UUID values differently than Trino (perhaps due to https://bugs.openjdk.org/browse/JDK-7025832), so allow only IS NULL / IS NOT NULL checks - (columnHandle.getType() == UUID && !(domain.isOnlyNull() || domain.getValues().isAll()))) { - unsupported.put(columnHandle, domain); - } - else if (canEnforceColumnConstraintInAllSpecs(typeOperators, icebergTable, columnHandle, domain)) { - newEnforced.put(columnHandle, domain); - } - else { - newUnenforced.put(columnHandle, domain); - } - }); + TupleDomain newEnforcedConstraint; + TupleDomain newUnenforcedConstraint; + TupleDomain remainingConstraint; + if (predicate.isNone()) { + // Engine does not pass none Constraint.summary. It can become none when combined with the expression and connector's domain knowledge. + newEnforcedConstraint = TupleDomain.none(); + newUnenforcedConstraint = TupleDomain.all(); + remainingConstraint = TupleDomain.all(); + } + else { + Table icebergTable = catalog.loadTable(session, table.getSchemaTableName()); + + Map unsupported = new LinkedHashMap<>(); + Map newEnforced = new LinkedHashMap<>(); + Map newUnenforced = new LinkedHashMap<>(); + Map domains = predicate.getDomains().orElseThrow(() -> new VerifyException("No domains")); + domains.forEach((columnHandle, domain) -> { + // Iceberg metadata columns can not be used to filter a table scan in Iceberg library + // TODO (https://github.com/trinodb/trino/issues/8759) structural types cannot be used to filter a table scan in Iceberg library. + if (isMetadataColumnId(columnHandle.getId()) || isStructuralType(columnHandle.getType()) || + // Iceberg orders UUID values differently than Trino (perhaps due to https://bugs.openjdk.org/browse/JDK-7025832), so allow only IS NULL / IS NOT NULL checks + (columnHandle.getType() == UUID && !(domain.isOnlyNull() || domain.getValues().isAll()))) { + unsupported.put(columnHandle, domain); + } + else if (canEnforceColumnConstraintInAllSpecs(typeOperators, icebergTable, columnHandle, domain)) { + newEnforced.put(columnHandle, domain); + } + else { + newUnenforced.put(columnHandle, domain); + } + }); - TupleDomain newEnforcedConstraint = TupleDomain.withColumnDomains(newEnforced).intersect(table.getEnforcedPredicate()); - TupleDomain newUnenforcedConstraint = TupleDomain.withColumnDomains(newUnenforced).intersect(table.getUnenforcedPredicate()); + newEnforcedConstraint = TupleDomain.withColumnDomains(newEnforced).intersect(table.getEnforcedPredicate()); + newUnenforcedConstraint = TupleDomain.withColumnDomains(newUnenforced).intersect(table.getUnenforcedPredicate()); + remainingConstraint = TupleDomain.withColumnDomains(newUnenforced).intersect(TupleDomain.withColumnDomains(unsupported)); + } if (newEnforcedConstraint.equals(table.getEnforcedPredicate()) && newUnenforcedConstraint.equals(table.getUnenforcedPredicate())) { return Optional.empty(); } - TupleDomain remainingConstraint = TupleDomain.withColumnDomains(newUnenforced).intersect(TupleDomain.withColumnDomains(unsupported)); return Optional.of(new ConstraintApplicationResult<>( new IcebergTableHandle( table.getSchemaName(), @@ -1776,6 +1788,7 @@ else if (canEnforceColumnConstraintInAllSpecs(typeOperators, icebergTable, colum table.isRecordScannedFiles(), table.getMaxScannedFileSize()), remainingConstraint.transformKeys(ColumnHandle.class::cast), + extractionResult.getRemainingExpression(), false)); } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index ceb7ca1993fb..a4bec15aca57 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -31,6 +31,8 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.ValuesNode; import io.trino.testing.BaseConnectorTest; import io.trino.testing.DataProviders; import io.trino.testing.MaterializedResult; @@ -95,8 +97,10 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.BROADCAST; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.testing.MaterializedResult.resultBuilder; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -1217,6 +1221,8 @@ public void testHourTransform() assertThat(query("SELECT * FROM test_hour_transform WHERE d >= DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform WHERE CAST(d AS date) >= DATE '2015-05-15'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform WHERE d >= TIMESTAMP '2015-05-15 12:00:00'")) .isFullyPushedDown(); @@ -1279,6 +1285,8 @@ public void testDayTransformDate() assertThat(query("SELECT * FROM test_day_transform_date WHERE d >= DATE '2015-01-13'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_date WHERE CAST(d AS date) >= DATE '2015-01-13'")) + .isFullyPushedDown(); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_day_transform_date WHERE d >= TIMESTAMP '2015-01-13 00:00:00'")) @@ -1355,6 +1363,8 @@ public void testDayTransformTimestamp() assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d >= DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-05-15'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d >= TIMESTAMP '2015-05-15 00:00:00'")) .isFullyPushedDown(); @@ -1432,6 +1442,13 @@ public void testDayTransformTimestampWithTimeZone() assertThat(query("SELECT * FROM test_day_transform_timestamptz WHERE d >= with_timezone(DATE '2015-05-15', 'UTC')")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-15'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-15' AND d < TIMESTAMP '2015-05-15 02:00:00 Europe/Warsaw'")) + // Engine can eliminate the table scan after connector accepts the filter pushdown + .hasPlan(node(OutputNode.class, node(ValuesNode.class))) + .returnsEmptyResult(); + assertThat(query("SELECT * FROM test_day_transform_timestamptz WHERE d >= TIMESTAMP '2015-05-15 00:00:00 UTC'")) .isFullyPushedDown(); assertThat(query("SELECT * FROM test_day_transform_timestamptz WHERE d >= TIMESTAMP '2015-05-15 00:00:00.000001 UTC'")) @@ -1499,6 +1516,10 @@ public void testMonthTransformDate() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_month_transform_date WHERE d >= DATE '2020-06-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_date WHERE CAST(d AS date) >= DATE '2020-06-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_month_transform_date WHERE CAST(d AS date) >= DATE '2020-06-02'")) + .isNotFullyPushedDown(FilterNode.class); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_month_transform_date WHERE d >= TIMESTAMP '2015-06-01 00:00:00'")) @@ -1575,6 +1596,10 @@ public void testMonthTransformTimestamp() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE d >= DATE '2015-05-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-05-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-05-02'")) + .isNotFullyPushedDown(FilterNode.class); assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE d >= TIMESTAMP '2015-05-01 00:00:00'")) .isFullyPushedDown(); @@ -1652,6 +1677,15 @@ public void testMonthTransformTimestampWithTimeZone() assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE d >= with_timezone(DATE '2015-05-02', 'UTC')")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-02'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-01' AND d < TIMESTAMP '2015-05-01 02:00:00 Europe/Warsaw'")) + // Engine can eliminate the table scan after connector accepts the filter pushdown + .hasPlan(node(OutputNode.class, node(ValuesNode.class))) + .returnsEmptyResult(); + assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE d >= TIMESTAMP '2015-05-01 00:00:00 UTC'")) .isFullyPushedDown(); assertThat(query("SELECT * FROM test_month_transform_timestamptz WHERE d >= TIMESTAMP '2015-05-01 00:00:00.000001 UTC'")) @@ -1714,6 +1748,10 @@ public void testYearTransformDate() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_date WHERE d >= DATE '2015-01-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_date WHERE CAST(d AS date) >= DATE '2015-01-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_year_transform_date WHERE CAST(d AS date) >= DATE '2015-01-02'")) + .isNotFullyPushedDown(FilterNode.class); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_year_transform_date WHERE d >= TIMESTAMP '2015-01-01 00:00:00'")) @@ -1788,6 +1826,10 @@ public void testYearTransformTimestamp() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE d >= DATE '2015-01-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-01-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-01-02'")) + .isNotFullyPushedDown(FilterNode.class); assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE d >= TIMESTAMP '2015-01-01 00:00:00'")) .isFullyPushedDown(); @@ -1863,6 +1905,15 @@ public void testYearTransformTimestampWithTimeZone() assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d >= with_timezone(DATE '2015-01-02', 'UTC')")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-01-01'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-01-02'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-01-01' AND d < TIMESTAMP '2015-01-01 01:00:00 Europe/Warsaw'")) + // Engine can eliminate the table scan after connector accepts the filter pushdown + .hasPlan(node(OutputNode.class, node(ValuesNode.class))) + .returnsEmptyResult(); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d >= TIMESTAMP '2015-01-01 00:00:00 UTC'")) .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d >= TIMESTAMP '2015-01-01 00:00:00.000001 UTC'")) @@ -3668,6 +3719,21 @@ public void testOptimizeTimePartitionedTable(String dataType, String partitionin .as("file count after optimize date, after the optimize") .isEqualTo(expectedFilesAfterOptimize); + // Verify that WHERE CAST(p AS date) ... form works in non-UTC zone + assertUpdate( + Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey("Asia/Kathmandu")) + .build(), + "ALTER TABLE " + tableName + " EXECUTE optimize WHERE CAST(p AS date) >= " + optimizeDate); + + // Table state shouldn't change substantially (but files may be rewritten) + assertThat((long) computeScalar("SELECT count(DISTINCT \"$path\") FROM " + tableName + " WHERE p < " + optimizeDate)) + .as("file count before optimize date, after the second optimize") + .isEqualTo(filesBeforeOptimizeDate); + assertThat((long) computeScalar("SELECT count(DISTINCT \"$path\") FROM " + tableName + " WHERE p >= " + optimizeDate)) + .as("file count after optimize date, after the second optimize") + .isEqualTo(expectedFilesAfterOptimize); + assertUpdate("DROP TABLE " + tableName); } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java new file mode 100644 index 000000000000..566fa4b528b2 --- /dev/null +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestConstraintExtractor.java @@ -0,0 +1,258 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.iceberg; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.Constraint; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Constant; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.DateType; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.Type; +import io.trino.sql.planner.ConnectorExpressionTranslator; +import io.trino.sql.planner.LiteralEncoder; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.iterative.rule.UnwrapCastInComparison; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.SymbolReference; +import org.testng.annotations.Test; + +import java.time.LocalDate; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.plugin.iceberg.ColumnIdentity.primitiveColumnIdentity; +import static io.trino.plugin.iceberg.ConstraintExtractor.extractTupleDomain; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; +import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; +import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; +import static java.time.ZoneOffset.UTC; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestConstraintExtractor +{ + private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); + + private static final AtomicInteger nextColumnId = new AtomicInteger(1); + + private static final IcebergColumnHandle A_BIGINT = newPrimitiveColumn(BIGINT); + private static final IcebergColumnHandle A_TIMESTAMP_TZ = newPrimitiveColumn(TIMESTAMP_TZ_MICROS); + + @Test + public void testExtractSummary() + { + assertThat(extract( + new Constraint( + TupleDomain.withColumnDomains(Map.of(A_BIGINT, Domain.singleValue(BIGINT, 1L))), + Constant.TRUE, + Map.of(), + values -> { + throw new AssertionError("should not be called"); + }, + Set.of(A_BIGINT)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_BIGINT, Domain.singleValue(BIGINT, 1L)))); + } + + /** + * Test equivalent of {@link UnwrapCastInComparison} for {@link TimestampWithTimeZoneType}. + * {@link UnwrapCastInComparison} handles {@link DateType} and {@link TimestampType}, but cannot handle + * {@link TimestampWithTimeZoneType}. Such unwrap would not be monotonic. Within Iceberg, we know + * that {@link TimestampWithTimeZoneType} is always in UTC zone (point in time, with no time zone information), + * so we can unwrap. + */ + @Test + public void testExtractTimestampTzDateComparison() + { + String timestampTzColumnSymbol = "timestamp_tz_symbol"; + Cast castOfColumn = new Cast(new SymbolReference(timestampTzColumnSymbol), toSqlType(DATE)); + + LocalDate someDate = LocalDate.of(2005, 9, 10); + Expression someDateExpression = LITERAL_ENCODER.toExpression(TEST_SESSION, someDate.toEpochDay(), DATE); + + long startOfDateUtcEpochMillis = someDate.atStartOfDay().toEpochSecond(UTC) * MILLISECONDS_PER_SECOND; + LongTimestampWithTimeZone startOfDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis); + LongTimestampWithTimeZone startOfNextDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis + MILLISECONDS_PER_DAY); + + assertThat(extract( + constraint( + new ComparisonExpression(EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.range(TIMESTAMP_TZ_MICROS, startOfDateUtc, true, startOfNextDateUtc, false))))); + + assertThat(extract( + constraint( + new ComparisonExpression(NOT_EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain( + Range.lessThan(TIMESTAMP_TZ_MICROS, startOfDateUtc), + Range.greaterThanOrEqual(TIMESTAMP_TZ_MICROS, startOfNextDateUtc))))); + + assertThat(extract( + constraint( + new ComparisonExpression(LESS_THAN, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.lessThan(TIMESTAMP_TZ_MICROS, startOfDateUtc))))); + + assertThat(extract( + constraint( + new ComparisonExpression(LESS_THAN_OR_EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.lessThan(TIMESTAMP_TZ_MICROS, startOfNextDateUtc))))); + + assertThat(extract( + constraint( + new ComparisonExpression(GREATER_THAN, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.greaterThanOrEqual(TIMESTAMP_TZ_MICROS, startOfNextDateUtc))))); + + assertThat(extract( + constraint( + new ComparisonExpression(GREATER_THAN_OR_EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.greaterThanOrEqual(TIMESTAMP_TZ_MICROS, startOfDateUtc))))); + + assertThat(extract( + constraint( + new ComparisonExpression(IS_DISTINCT_FROM, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, Domain.create( + ValueSet.ofRanges( + Range.lessThan(TIMESTAMP_TZ_MICROS, startOfDateUtc), + Range.greaterThanOrEqual(TIMESTAMP_TZ_MICROS, startOfNextDateUtc)), + true)))); + } + + @Test + public void testIntersectSummaryAndExpressionExtraction() + { + String timestampTzColumnSymbol = "timestamp_tz_symbol"; + Cast castOfColumn = new Cast(new SymbolReference(timestampTzColumnSymbol), toSqlType(DATE)); + + LocalDate someDate = LocalDate.of(2005, 9, 10); + Expression someDateExpression = LITERAL_ENCODER.toExpression(TEST_SESSION, someDate.toEpochDay(), DATE); + + long startOfDateUtcEpochMillis = someDate.atStartOfDay().toEpochSecond(UTC) * MILLISECONDS_PER_SECOND; + LongTimestampWithTimeZone startOfDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis); + LongTimestampWithTimeZone startOfNextDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis + MILLISECONDS_PER_DAY); + LongTimestampWithTimeZone startOfNextNextDateUtc = timestampTzFromEpochMillis(startOfDateUtcEpochMillis + MILLISECONDS_PER_DAY * 2); + + assertThat(extract( + constraint( + TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.lessThan(TIMESTAMP_TZ_MICROS, startOfNextNextDateUtc)))), + new ComparisonExpression(NOT_EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of( + A_TIMESTAMP_TZ, domain( + Range.lessThan(TIMESTAMP_TZ_MICROS, startOfDateUtc), + Range.range(TIMESTAMP_TZ_MICROS, startOfNextDateUtc, true, startOfNextNextDateUtc, false))))); + + assertThat(extract( + constraint( + TupleDomain.withColumnDomains(Map.of(A_TIMESTAMP_TZ, domain(Range.lessThan(TIMESTAMP_TZ_MICROS, startOfNextDateUtc)))), + new ComparisonExpression(GREATER_THAN, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.none()); + + assertThat(extract( + constraint( + TupleDomain.withColumnDomains(Map.of(A_BIGINT, Domain.singleValue(BIGINT, 1L))), + new ComparisonExpression(GREATER_THAN_OR_EQUAL, castOfColumn, someDateExpression), + Map.of(timestampTzColumnSymbol, A_TIMESTAMP_TZ)))) + .isEqualTo(TupleDomain.withColumnDomains(Map.of( + A_BIGINT, Domain.singleValue(BIGINT, 1L), + A_TIMESTAMP_TZ, domain(Range.greaterThanOrEqual(TIMESTAMP_TZ_MICROS, startOfDateUtc))))); + } + + private static IcebergColumnHandle newPrimitiveColumn(Type type) + { + int id = nextColumnId.getAndIncrement(); + return new IcebergColumnHandle( + primitiveColumnIdentity(id, "column_" + id), + type, + ImmutableList.of(), + type, + Optional.empty()); + } + + private static TupleDomain extract(Constraint constraint) + { + ConstraintExtractor.ExtractionResult result = extractTupleDomain(constraint); + assertThat(result.getRemainingExpression()) + .isEqualTo(Constant.TRUE); + return result.getTupleDomain(); + } + + private static Constraint constraint(Expression expression, Map assignments) + { + return constraint(TupleDomain.all(), expression, assignments); + } + + private static Constraint constraint(TupleDomain summary, Expression expression, Map assignments) + { + Map symbolTypes = assignments.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getType())); + ConnectorExpression connectorExpression = connectorExpression(expression, symbolTypes); + return new Constraint(summary, connectorExpression, ImmutableMap.copyOf(assignments)); + } + + private static ConnectorExpression connectorExpression(Expression expression, Map symbolTypes) + { + return ConnectorExpressionTranslator.translate( + TEST_SESSION, + expression, + createTestingTypeAnalyzer(PLANNER_CONTEXT), + TypeProvider.viewOf(symbolTypes.entrySet().stream() + .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))), + PLANNER_CONTEXT) + .orElseThrow(); + } + + private static LongTimestampWithTimeZone timestampTzFromEpochMillis(long epochMillis) + { + return LongTimestampWithTimeZone.fromEpochMillisAndFraction(epochMillis, 0, UTC_KEY); + } + + private static Domain domain(Range first, Range... rest) + { + return Domain.create(ValueSet.ofRanges(first, rest), false); + } +}