From 3180b70669bc89cd2cf41dec08026be4af696b0d Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Fri, 22 Apr 2022 13:29:01 +0200 Subject: [PATCH 1/4] Fix JDBC constant rewrites for NULL literal --- .../plugin/jdbc/expression/RewriteExactNumericConstant.java | 4 ++++ .../trino/plugin/jdbc/expression/RewriteVarcharConstant.java | 3 +++ 2 files changed, 7 insertions(+) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java index f3f83d690586..284ced0164c9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteExactNumericConstant.java @@ -46,6 +46,10 @@ public Pattern getPattern() @Override public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { + if (constant.getValue() == null) { + return Optional.empty(); + } + Type type = constant.getType(); if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT) { return Optional.of(Long.toString((long) constant.getValue())); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java index f63d1b399d68..94ada1e9ecd4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java @@ -40,6 +40,9 @@ public Pattern getPattern() public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { Slice slice = (Slice) constant.getValue(); + if (slice == null) { + return Optional.empty(); + } return Optional.of("'" + slice.toStringUtf8().replace("'", "''") + "'"); } } From 1483afb8c183d2bc801403510e4299d08fad5ecb Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 9 Mar 2022 15:49:15 +0100 Subject: [PATCH 2/4] Translate IN predicate to connector expression --- .../ConnectorExpressionTranslator.java | 90 +++++++++++++++++-- .../TestConnectorExpressionTranslator.java | 34 +++++++ .../spi/expression/StandardFunctions.java | 12 +++ 3 files changed, 127 insertions(+), 9 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 2084e9606fd1..1b060397fa7f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -28,6 +28,7 @@ import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.Variable; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Decimals; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -50,6 +51,8 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.GenericLiteral; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -80,11 +83,13 @@ import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; @@ -272,6 +277,10 @@ protected Optional translateCall(Call call) } } + if (IN_PREDICATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) { + return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1)); + } + QualifiedName name = QualifiedName.of(call.getFunctionName().getName()); List argumentTypes = call.getArguments().stream() .map(argument -> argument.getType().getTypeSignature()) @@ -344,15 +353,8 @@ private Optional translateCast(Type type, ConnectorExpression expres private Optional translateLogicalExpression(LogicalExpression.Operator operator, List arguments) { - ImmutableList.Builder translatedArguments = ImmutableList.builderWithExpectedSize(arguments.size()); - for (ConnectorExpression argument : arguments) { - Optional translated = translate(argument); - if (translated.isEmpty()) { - return Optional.empty(); - } - translatedArguments.add(translated.get()); - } - return Optional.of(new LogicalExpression(operator, translatedArguments.build())); + Optional> translatedArguments = translateExpressions(arguments); + return translatedArguments.map(expressions -> new LogicalExpression(operator, expressions)); } private Optional translateComparison(ComparisonExpression.Operator operator, ConnectorExpression left, ConnectorExpression right) @@ -446,6 +448,46 @@ protected Optional translateLike(ConnectorExpression value, Connecto return Optional.empty(); } + + protected Optional translateInPredicate(ConnectorExpression value, ConnectorExpression values) + { + Optional translatedValue = translate(value); + Optional> translatedValues = extractExpressionsFromArrayCall(values); + + if (translatedValue.isPresent() && translatedValues.isPresent()) { + return Optional.of(new InPredicate(translatedValue.get(), new InListExpression(translatedValues.get()))); + } + + return Optional.empty(); + } + + protected Optional> extractExpressionsFromArrayCall(ConnectorExpression expression) + { + if (!(expression instanceof Call)) { + return Optional.empty(); + } + + Call call = (Call) expression; + if (!call.getFunctionName().equals(ARRAY_CONSTRUCTOR_FUNCTION_NAME)) { + return Optional.empty(); + } + + return translateExpressions(call.getArguments()); + } + + protected Optional> translateExpressions(List expressions) + { + ImmutableList.Builder translatedExpressions = ImmutableList.builderWithExpectedSize(expressions.size()); + for (ConnectorExpression expression : expressions) { + Optional translated = translate(expression); + if (translated.isEmpty()) { + return Optional.empty(); + } + translatedExpressions.add(translated.get()); + } + + return Optional.of(translatedExpressions.build()); + } } public static class SqlToConnectorExpressionTranslator @@ -760,6 +802,36 @@ protected Optional visitSubscriptExpression(SubscriptExpres return Optional.of(new FieldDereference(typeOf(node), translatedBase.get(), toIntExact(((LongLiteral) node.getIndex()).getValue() - 1))); } + @Override + protected Optional visitInPredicate(InPredicate node, Void context) + { + InListExpression valueList = (InListExpression) node.getValueList(); + Optional valueExpression = process(node.getValue()); + + if (valueExpression.isEmpty()) { + return Optional.empty(); + } + + ImmutableList.Builder values = ImmutableList.builderWithExpectedSize(valueList.getValues().size()); + for (Expression value : valueList.getValues()) { + // TODO: NULL should be eliminated on the engine side (within a rule) + if (value == null || value instanceof NullLiteral) { + return Optional.empty(); + } + + Optional processedValue = process(value); + + if (processedValue.isEmpty()) { + return Optional.empty(); + } + + values.add(processedValue.get()); + } + + ConnectorExpression arrayExpression = new Call(new ArrayType(typeOf(node.getValueList())), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build()); + return Optional.of(new Call(typeOf(node), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression))); + } + @Override protected Optional visitExpression(Expression node, Void context) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index f19d68a7441e..0eea7ecde92a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -24,6 +24,7 @@ import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.StandardFunctions; import io.trino.spi.expression.Variable; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.tree.ArithmeticBinaryExpression; @@ -34,6 +35,8 @@ import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -41,6 +44,7 @@ import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullIfExpression; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubscriptExpression; @@ -59,6 +63,7 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; @@ -94,6 +99,8 @@ public class TestConnectorExpressionTranslator private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))); private static final VarcharType VARCHAR_TYPE = createVarcharType(25); + private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE); + private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); private static final Map symbols = ImmutableMap.builder() @@ -418,6 +425,33 @@ public void testTranslateRegularExpression() }); } + @Test + public void testTranslateIn() + { + String value = "value_1"; + assertTranslationRoundTrips( + new InPredicate( + new SymbolReference("varchar_symbol_1"), + new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new StringLiteral(value)))), + new Call( + BOOLEAN, + StandardFunctions.IN_PREDICATE_FUNCTION_NAME, + List.of( + new Variable("varchar_symbol_1", VARCHAR_TYPE), + new Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME, + List.of( + new Variable("varchar_symbol_1", VARCHAR_TYPE), + new Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), createVarcharType(value.length()))))))); + + // IN (null) is not translated + assertTranslationToConnectorExpression( + TEST_SESSION, + new InPredicate( + new SymbolReference("varchar_symbol_1"), + new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new NullLiteral()))), + Optional.empty()); + } + private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression) { assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression); diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index cd7694efb7c1..2784697f6575 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -82,4 +82,16 @@ private StandardFunctions() {} public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate"); public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern"); + + /** + * {@code $in(value, array)} returns {@code true} when value is equal to an element of the array, + * otherwise returns {@code NULL} when comparing value to an element of the array returns an + * indeterminate result, otherwise returns {@code false} + */ + public static final FunctionName IN_PREDICATE_FUNCTION_NAME = new FunctionName("$in"); + + /** + * $array creates instance of {@link Array Type} + */ + public static final FunctionName ARRAY_CONSTRUCTOR_FUNCTION_NAME = new FunctionName("$array"); } From 584231ad102dc5cca610ff91e6dcb5923b84205b Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Thu, 10 Mar 2022 13:27:24 +0100 Subject: [PATCH 3/4] Rewrite connector IN expression in PostgreSQL connector --- .../ConnectorExpressionPatterns.java | 5 ++ .../plugin/jdbc/expression/RewriteIn.java | 89 +++++++++++++++++++ .../plugin/postgresql/PostgreSqlClient.java | 2 + .../postgresql/TestPostgreSqlClient.java | 33 ++++++- .../TestPostgreSqlConnectorTest.java | 29 ++++++ 5 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java index ecfe3b28e2ee..7666673300f2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java @@ -79,6 +79,11 @@ public static Pattern call() return Property.property("argumentCount", call -> call.getArguments().size()); } + public static Property> arguments() + { + return Property.property("arguments", Call::getArguments); + } + public static Property argument(int argument) { checkArgument(0 <= argument, "Invalid argument index: %s", argument); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java new file mode 100644 index 000000000000..0b81f689b502 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.arguments; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.lang.String.format; + +public class RewriteIn + implements ConnectorExpressionRule +{ + private static final Capture VALUE = newCapture(); + private static final Capture> EXPRESSIONS = newCapture(); + + private static final Pattern PATTERN = call() + .with(functionName().equalTo(IN_PREDICATE_FUNCTION_NAME)) + .with(type().equalTo(BOOLEAN)) + .with(argumentCount().equalTo(2)) + .with(argument(0).matching(expression().capturedAs(VALUE))) + .with(argument(1).matching(call().with(functionName().equalTo(ARRAY_CONSTRUCTOR_FUNCTION_NAME)).with(arguments().capturedAs(EXPRESSIONS)))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(Call call, Captures captures, RewriteContext context) + { + Optional value = context.defaultRewrite(captures.get(VALUE)); + if (value.isEmpty()) { + return Optional.empty(); + } + + List expressions = captures.get(EXPRESSIONS); + if (expressions.size() > getDomainCompactionThreshold(context.getSession())) { + // We don't want to push down too long IN query text + return Optional.empty(); + } + + ImmutableList.Builder rewrittenValues = ImmutableList.builderWithExpectedSize(expressions.size()); + for (ConnectorExpression expression : expressions) { + Optional rewrittenExpression = context.defaultRewrite(expression); + if (rewrittenExpression.isEmpty()) { + return Optional.empty(); + } + rewrittenValues.add(rewrittenExpression.get()); + } + + List values = rewrittenValues.build(); + verify(!values.isEmpty(), "Empty values"); + return Optional.of(format("(%s) IN (%s)", value.get(), Joiner.on(", ").join(values))); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 4dbf44840cd6..67dd6668e0e6 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -65,6 +65,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteComparison; +import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; import io.trino.spi.TrinoException; @@ -300,6 +301,7 @@ public PostgreSqlClient( .addStandardRules(this::quoted) // TODO allow all comparison operators for numeric types .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index a9d6bca510bd..7aae0329b414 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -20,11 +20,14 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcMetadataConfig; +import io.trino.plugin.jdbc.JdbcMetadataSessionProperties; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; @@ -36,6 +39,8 @@ import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -44,6 +49,7 @@ import io.trino.sql.tree.NullIfExpression; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; +import io.trino.testing.TestingConnectorSession; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -63,7 +69,6 @@ import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.DataProviders.toDataProvider; -import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.testing.assertions.Assert.assertEquals; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static java.lang.String.format; @@ -93,6 +98,13 @@ public class TestPostgreSqlClient .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty())) .build(); + private static final JdbcColumnHandle VARCHAR_COLUMN2 = + JdbcColumnHandle.builder() + .setColumnName("c_varchar2") + .setColumnType(createVarcharType(10)) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty())) + .build(); + private static final JdbcClient JDBC_CLIENT = new PostgreSqlClient( new BaseJdbcConfig(), new PostgreSqlConfig(), @@ -104,6 +116,11 @@ public class TestPostgreSqlClient private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); + private static final ConnectorSession SESSION = TestingConnectorSession + .builder() + .setPropertyMetadata(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties()) + .build(); + @Test public void testImplementCount() { @@ -410,6 +427,20 @@ public void testConvertNotExpression() .hasValue("NOT ((\"c_varchar\") IS NOT NULL)"); } + @Test + public void testConvertIn() + { + assertThat(JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new InPredicate( + new SymbolReference("c_varchar"), + new InListExpression(List.of(new StringLiteral("value1"), new StringLiteral("value2"), new SymbolReference("c_varchar2")))), + Map.of("c_varchar", VARCHAR_COLUMN.getColumnType(), "c_varchar2", VARCHAR_COLUMN2.getColumnType())), + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2))) + .hasValue("(\"c_varchar\") IN ('value1', 'value2', \"c_varchar2\")"); + } + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) { return ConnectorExpressionTranslator.translate( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index cae377cbb433..55ad0c437314 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -811,6 +811,35 @@ public void testNotExpressionPushdown() } } + @Test + public void testInPredicatePushdown() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_in_predicate_pushdown", + "(id varchar(1), id2 varchar(1))", + List.of( + "'a', 'b'", + "'b', 'c'", + "'c', 'c'", + "'d', 'd'", + "'a', 'f'"))) { + // IN values cannot be represented as a domain + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', id2)")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'b') OR id2 IN ('c', 'd')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B') OR id2 IN ('c', 'D')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', NULL) OR id2 IN ('C', 'd')")) + // NULL constant value is currently not pushed down + .isNotFullyPushedDown(FilterNode.class); + } + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) { From 57fc12ac8405c447ea6ae4389d9a605d21cf37c1 Mon Sep 17 00:00:00 2001 From: Assaf Bern Date: Mon, 25 Jul 2022 10:07:32 +0300 Subject: [PATCH 4/4] Empty Commit