Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert discrete domain to Iceberg IN expression #10032

Merged
merged 2 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
import io.airlift.slice.Slice;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.SortedRangeSet;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
Expand All @@ -37,13 +35,15 @@
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.plugin.iceberg.util.Timestamps.timestampTzToMicros;
import static io.trino.spi.type.TimeType.TIME_MICROS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
Expand All @@ -52,18 +52,19 @@
import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.apache.iceberg.expressions.Expressions.alwaysFalse;
import static org.apache.iceberg.expressions.Expressions.alwaysTrue;
import static org.apache.iceberg.expressions.Expressions.and;
import static org.apache.iceberg.expressions.Expressions.equal;
import static org.apache.iceberg.expressions.Expressions.greaterThan;
import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.in;
import static org.apache.iceberg.expressions.Expressions.isNull;
import static org.apache.iceberg.expressions.Expressions.lessThan;
import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.not;
import static org.apache.iceberg.expressions.Expressions.or;

public final class ExpressionConverter
{
Expand Down Expand Up @@ -106,23 +107,25 @@ private static Expression toIcebergExpression(String columnName, Type type, Doma
throw new UnsupportedOperationException("Unsupported type for expression: " + type);
}

ValueSet domainValues = domain.getValues();
Expression expression = null;
if (domain.isNullAllowed()) {
expression = isNull(columnName);
}

if (domainValues instanceof SortedRangeSet) {
List<Range> orderedRanges = ((SortedRangeSet) domainValues).getOrderedRanges();
expression = firstNonNull(expression, alwaysFalse());

if (type.isOrderable()) {
List<Range> orderedRanges = domain.getValues().getRanges().getOrderedRanges();
List<Object> icebergValues = new ArrayList<>();
List<Expression> rangeExpressions = new ArrayList<>();
for (Range range : orderedRanges) {
expression = or(expression, toIcebergExpression(columnName, range));
if (range.isSingleValue()) {
icebergValues.add(getIcebergLiteralValue(type, range.getLowBoundedValue()));
}
else {
rangeExpressions.add(toIcebergExpression(columnName, range));
}
}
return expression;
Expression ranges = or(rangeExpressions);
Expression values = icebergValues.isEmpty() ? alwaysFalse() : in(columnName, icebergValues);
Expression nullExpression = domain.isNullAllowed() ? isNull(columnName) : alwaysFalse();
return or(nullExpression, or(values, ranges));
}

throw new VerifyException("Did not expect a domain value set other than SortedRangeSet but got " + domainValues.getClass().getSimpleName());
throw new VerifyException(format("Unsupported type %s with domain values %s", type, domain));
}

private static Expression toIcebergExpression(String columnName, Range range)
Expand Down Expand Up @@ -228,4 +231,21 @@ private static Object getIcebergLiteralValue(Type type, Object trinoNativeValue)

throw new UnsupportedOperationException("Unsupported type: " + type);
}

/**
 * Returns the logical OR of two Iceberg expressions.
 *
 * <p>This is a plain delegate to {@link Expressions#or(Expression, Expression)}.
 * NOTE(review): presumably it exists so the two-argument form remains callable
 * as {@code or(...)} next to the list-based overload declared in this class.
 *
 * @param left the first operand of the disjunction
 * @param right the second operand of the disjunction
 * @return whatever {@code Expressions.or(left, right)} produces
 */
private static Expression or(Expression left, Expression right)
{
    Expression disjunction = Expressions.or(left, right);
    return disjunction;
}

/**
 * Returns the logical OR of every expression in the list.
 *
 * <p>The list is split at its midpoint and each half is combined recursively,
 * so the resulting expression is a roughly balanced binary tree rather than a
 * chain whose nesting grows with the number of operands.
 *
 * @param expressions the operands to combine; may be empty
 * @return {@code alwaysFalse()} for an empty list, the sole element for a
 *         single-element list, otherwise the OR of both recursively built halves
 */
private static Expression or(List<Expression> expressions)
{
    int count = expressions.size();
    switch (count) {
        case 0:
            // An empty disjunction matches nothing.
            return alwaysFalse();
        case 1:
            return getOnlyElement(expressions);
        default:
            int split = count / 2;
            // Build the left half before the right half, mirroring argument evaluation order.
            Expression lowerHalf = or(expressions.subList(0, split));
            Expression upperHalf = or(expressions.subList(split, count));
            return or(lowerHalf, upperHalf);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,6 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
Set<Integer> partitionSourceIds = identityPartitionColumnsInAllSpecs(icebergTable);
BiPredicate<IcebergColumnHandle, Domain> isIdentityPartition = (column, domain) -> partitionSourceIds.contains(column.getId());

// TODO: Avoid enforcing the constraint when partition filters have large IN expressions, since iceberg cannot
// support it. Such large expressions cannot be simplified since simplification changes the filtered set.
TupleDomain<IcebergColumnHandle> newEnforcedConstraint = constraint.getSummary()
.transformKeys(IcebergColumnHandle.class::cast)
.filter(isIdentityPartition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ public void testSchemaEvolution()
}

@Test
public void testLargeInFailureOnPartitionedColumns()
public void testLargeInOnPartitionedColumns()
{
assertUpdate("CREATE TABLE test_large_in_failure (col1 BIGINT, col2 BIGINT) WITH (partitioning = ARRAY['col2'])");
assertUpdate("INSERT INTO test_large_in_failure VALUES (1, 10)", 1L);
Expand All @@ -1076,11 +1076,9 @@ public void testLargeInFailureOnPartitionedColumns()
List<String> predicates = IntStream.range(0, 25_000).boxed()
.map(Object::toString)
.collect(toImmutableList());

String filter = format("col2 IN (%s)", join(",", predicates));
assertThatThrownBy(() -> getQueryRunner().execute(format("SELECT * FROM test_large_in_failure WHERE %s", filter)))
.isInstanceOf(RuntimeException.class)
.hasMessage("java.lang.StackOverflowError");
assertThat(query("SELECT * FROM test_large_in_failure WHERE " + filter))
.matches("TABLE test_large_in_failure");

dropTable("test_large_in_failure");
}
Expand Down