Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate IN predicate to connector expression - continue #13136

Merged
merged 4 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
Expand All @@ -50,6 +51,8 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand Down Expand Up @@ -80,11 +83,13 @@
import static io.trino.SystemSessionProperties.isComplexExpressionPushdown;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME;
Expand Down Expand Up @@ -272,6 +277,10 @@ protected Optional<Expression> translateCall(Call call)
}
}

if (IN_PREDICATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) {
return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1));
}

QualifiedName name = QualifiedName.of(call.getFunctionName().getName());
List<TypeSignature> argumentTypes = call.getArguments().stream()
.map(argument -> argument.getType().getTypeSignature())
Expand Down Expand Up @@ -344,15 +353,8 @@ private Optional<Expression> translateCast(Type type, ConnectorExpression expres

private Optional<Expression> translateLogicalExpression(LogicalExpression.Operator operator, List<ConnectorExpression> arguments)
{
ImmutableList.Builder<Expression> translatedArguments = ImmutableList.builderWithExpectedSize(arguments.size());
for (ConnectorExpression argument : arguments) {
Optional<Expression> translated = translate(argument);
if (translated.isEmpty()) {
return Optional.empty();
}
translatedArguments.add(translated.get());
}
return Optional.of(new LogicalExpression(operator, translatedArguments.build()));
Optional<List<Expression>> translatedArguments = translateExpressions(arguments);
return translatedArguments.map(expressions -> new LogicalExpression(operator, expressions));
}

private Optional<Expression> translateComparison(ComparisonExpression.Operator operator, ConnectorExpression left, ConnectorExpression right)
Expand Down Expand Up @@ -446,6 +448,46 @@ protected Optional<Expression> translateLike(ConnectorExpression value, Connecto

return Optional.empty();
}

protected Optional<Expression> translateInPredicate(ConnectorExpression value, ConnectorExpression values)
{
Optional<Expression> translatedValue = translate(value);
Optional<List<Expression>> translatedValues = extractExpressionsFromArrayCall(values);

if (translatedValue.isPresent() && translatedValues.isPresent()) {
return Optional.of(new InPredicate(translatedValue.get(), new InListExpression(translatedValues.get())));
}

return Optional.empty();
}

protected Optional<List<Expression>> extractExpressionsFromArrayCall(ConnectorExpression expression)
{
if (!(expression instanceof Call)) {
return Optional.empty();
}

Call call = (Call) expression;
if (!call.getFunctionName().equals(ARRAY_CONSTRUCTOR_FUNCTION_NAME)) {
return Optional.empty();
}

return translateExpressions(call.getArguments());
}

protected Optional<List<Expression>> translateExpressions(List<ConnectorExpression> expressions)
{
ImmutableList.Builder<Expression> translatedExpressions = ImmutableList.builderWithExpectedSize(expressions.size());
for (ConnectorExpression expression : expressions) {
Optional<Expression> translated = translate(expression);
if (translated.isEmpty()) {
return Optional.empty();
}
translatedExpressions.add(translated.get());
}

return Optional.of(translatedExpressions.build());
}
}

public static class SqlToConnectorExpressionTranslator
Expand Down Expand Up @@ -760,6 +802,36 @@ protected Optional<ConnectorExpression> visitSubscriptExpression(SubscriptExpres
return Optional.of(new FieldDereference(typeOf(node), translatedBase.get(), toIntExact(((LongLiteral) node.getIndex()).getValue() - 1)));
}

@Override
protected Optional<ConnectorExpression> visitInPredicate(InPredicate node, Void context)
{
InListExpression valueList = (InListExpression) node.getValueList();
Optional<ConnectorExpression> valueExpression = process(node.getValue());

if (valueExpression.isEmpty()) {
return Optional.empty();
}

ImmutableList.Builder<ConnectorExpression> values = ImmutableList.builderWithExpectedSize(valueList.getValues().size());
for (Expression value : valueList.getValues()) {
// TODO: NULL should be eliminated on the engine side (within a rule)
if (value == null || value instanceof NullLiteral) {
return Optional.empty();
}

Optional<ConnectorExpression> processedValue = process(value);

if (processedValue.isEmpty()) {
return Optional.empty();
}

values.add(processedValue.get());
}

ConnectorExpression arrayExpression = new Call(new ArrayType(typeOf(node.getValueList())), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build());
return Optional.of(new Call(typeOf(node), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression)));
}

@Override
protected Optional<ConnectorExpression> visitExpression(Expression node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.tree.ArithmeticBinaryExpression;
Expand All @@ -34,13 +35,16 @@
import io.trino.sql.tree.DoubleLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullIfExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.SubscriptExpression;
Expand All @@ -59,6 +63,7 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
Expand Down Expand Up @@ -94,6 +99,8 @@ public class TestConnectorExpressionTranslator
private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT);
private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5)));
private static final VarcharType VARCHAR_TYPE = createVarcharType(25);
private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE);

private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT);

private static final Map<Symbol, Type> symbols = ImmutableMap.<Symbol, Type>builder()
Expand Down Expand Up @@ -418,6 +425,33 @@ public void testTranslateRegularExpression()
});
}

@Test
public void testTranslateIn()
{
String value = "value_1";
assertTranslationRoundTrips(
new InPredicate(
new SymbolReference("varchar_symbol_1"),
new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new StringLiteral(value)))),
new Call(
BOOLEAN,
StandardFunctions.IN_PREDICATE_FUNCTION_NAME,
List.of(
new Variable("varchar_symbol_1", VARCHAR_TYPE),
new Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME,
List.of(
new Variable("varchar_symbol_1", VARCHAR_TYPE),
new Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), createVarcharType(value.length())))))));

// IN (null) is not translated
assertTranslationToConnectorExpression(
TEST_SESSION,
new InPredicate(
new SymbolReference("varchar_symbol_1"),
new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new NullLiteral()))),
Optional.empty());
}

private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression)
{
assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,16 @@ private StandardFunctions() {}
public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate");

public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern");

/**
* {@code $in(value, array)} returns {@code true} when value is equal to an element of the array,
* otherwise returns {@code NULL} when comparing value to an element of the array returns an
* indeterminate result, otherwise returns {@code false}
*/
public static final FunctionName IN_PREDICATE_FUNCTION_NAME = new FunctionName("$in");

/**
* $array creates instance of {@link Array Type}
*/
public static final FunctionName ARRAY_CONSTRUCTOR_FUNCTION_NAME = new FunctionName("$array");
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public static Pattern<Call> call()
return Property.property("argumentCount", call -> call.getArguments().size());
}

public static Property<Call, ?, List<ConnectorExpression>> arguments()
{
return Property.property("arguments", Call::getArguments);
}

public static Property<Call, ?, ConnectorExpression> argument(int argument)
{
checkArgument(0 <= argument, "Invalid argument index: %s", argument);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public Pattern<Constant> getPattern()
@Override
public Optional<String> rewrite(Constant constant, Captures captures, RewriteContext<String> context)
{
if (constant.getValue() == null) {
return Optional.empty();
}

Type type = constant.getType();
if (type == TINYINT || type == SMALLINT || type == INTEGER || type == BIGINT) {
return Optional.of(Long.toString((long) constant.getValue()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.arguments;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.lang.String.format;

public class RewriteIn
implements ConnectorExpressionRule<Call, String>
{
private static final Capture<ConnectorExpression> VALUE = newCapture();
private static final Capture<List<ConnectorExpression>> EXPRESSIONS = newCapture();

private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(IN_PREDICATE_FUNCTION_NAME))
.with(type().equalTo(BOOLEAN))
.with(argumentCount().equalTo(2))
.with(argument(0).matching(expression().capturedAs(VALUE)))
.with(argument(1).matching(call().with(functionName().equalTo(ARRAY_CONSTRUCTOR_FUNCTION_NAME)).with(arguments().capturedAs(EXPRESSIONS))));

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
}

@Override
public Optional<String> rewrite(Call call, Captures captures, RewriteContext<String> context)
{
Optional<String> value = context.defaultRewrite(captures.get(VALUE));
if (value.isEmpty()) {
return Optional.empty();
}

List<ConnectorExpression> expressions = captures.get(EXPRESSIONS);
if (expressions.size() > getDomainCompactionThreshold(context.getSession())) {
// We don't want to push down too long IN query text
return Optional.empty();
}

ImmutableList.Builder<String> rewrittenValues = ImmutableList.builderWithExpectedSize(expressions.size());
for (ConnectorExpression expression : expressions) {
Optional<String> rewrittenExpression = context.defaultRewrite(expression);
if (rewrittenExpression.isEmpty()) {
return Optional.empty();
}
rewrittenValues.add(rewrittenExpression.get());
}

List<String> values = rewrittenValues.build();
verify(!values.isEmpty(), "Empty values");
return Optional.of(format("(%s) IN (%s)", value.get(), Joiner.on(", ").join(values)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public Pattern<Constant> getPattern()
public Optional<String> rewrite(Constant constant, Captures captures, RewriteContext<String> context)
{
Slice slice = (Slice) constant.getValue();
if (slice == null) {
return Optional.empty();
assaf2 marked this conversation as resolved.
Show resolved Hide resolved
}
return Optional.of("'" + slice.toStringUtf8().replace("'", "''") + "'");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp;
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
import io.trino.plugin.jdbc.expression.RewriteComparison;
import io.trino.plugin.jdbc.expression.RewriteIn;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping;
import io.trino.spi.TrinoException;
Expand Down Expand Up @@ -300,6 +301,7 @@ public PostgreSqlClient(
.addStandardRules(this::quoted)
// TODO allow all comparison operators for numeric types
.add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL)))
.add(new RewriteIn())
.withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint"))
.map("$add(left: integer_type, right: integer_type)").to("left + right")
.map("$subtract(left: integer_type, right: integer_type)").to("left - right")
Expand Down
Loading