Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support integer arithmetic, IS (NOT) NULL, NULLIF pushdown in MySQL connector #20378

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions plugin/trino-mysql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-parser</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-plugin-toolkit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ public MySqlClient(

this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint"))
// No "real" on the list; pushdown on REAL is disabled also in toColumnMapping
.withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double"))
.map("$equal(left: numeric_type, right: numeric_type)").to("left = right")
Expand All @@ -279,8 +280,17 @@ public MySqlClient(
.map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right")
.map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right")
.map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right")
.map("$add(left: integer_type, right: integer_type)").to("left + right")
.map("$subtract(left: integer_type, right: integer_type)").to("left - right")
.map("$multiply(left: integer_type, right: integer_type)").to("left * right")
.map("$divide(left: integer_type, right: integer_type)").to("left / right")
.map("$modulus(left: integer_type, right: integer_type)").to("left % right")
.map("$negate(value: integer_type)").to("-value")
.add(new RewriteLikeWithCaseSensitivity())
.add(new RewriteLikeEscapeWithCaseSensitivity())
.map("$not($is_null(value))").to("value IS NOT NULL")
.map("$is_null(value)").to("value IS NULL")
.map("$nullif(first, second)").to("NULLIF(first, second)")
.build();

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.mysql;

import com.google.common.collect.ImmutableMap;
import io.trino.plugin.base.mapping.DefaultIdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
Expand All @@ -22,23 +23,43 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.NullIfExpression;
import io.trino.sql.tree.SymbolReference;
import org.junit.jupiter.api.Test;

import java.sql.Types;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static java.lang.String.format;
import static org.assertj.core.api.Assertions.assertThat;

public class TestMySqlClient
Expand All @@ -57,6 +78,13 @@ public class TestMySqlClient
.setJdbcTypeHandle(new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
.build();

private static final JdbcColumnHandle VARCHAR_COLUMN =
JdbcColumnHandle.builder()
.setColumnName("c_varchar")
.setColumnType(createVarcharType(10))
.setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty()))
.build();

private static final JdbcClient JDBC_CLIENT = new MySqlClient(
new BaseJdbcConfig(),
new JdbcStatisticsConfig(),
Expand All @@ -68,6 +96,8 @@ public class TestMySqlClient
new DefaultIdentifierMapping(),
RemoteQueryModifier.NONE);

private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT);

@Test
public void testImplementCount()
{
Expand Down Expand Up @@ -169,4 +199,99 @@ private static void testImplementAggregation(AggregateFunction aggregateFunction
.isEqualTo(aggregateFunction.getOutputType());
}
}

@Test
public void testConvertArithmeticBinary()
{
for (ArithmeticBinaryExpression.Operator operator : ArithmeticBinaryExpression.Operator.values()) {
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(
SESSION,
translateToConnectorExpression(
new ArithmeticBinaryExpression(
operator,
new SymbolReference("c_bigint_symbol"),
LITERAL_ENCODER.toExpression(42L, BIGINT)),
Map.of("c_bigint_symbol", BIGINT)),
Map.of("c_bigint_symbol", BIGINT_COLUMN))
.orElseThrow();

assertThat(converted.expression()).isEqualTo(format("(`c_bigint`) %s (?)", operator.getValue()));
assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(42L))));
}
}

@Test
public void testConvertArithmeticUnaryMinus()
{
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(
SESSION,
translateToConnectorExpression(
new ArithmeticUnaryExpression(
ArithmeticUnaryExpression.Sign.MINUS,
new SymbolReference("c_bigint_symbol")),
Map.of("c_bigint_symbol", BIGINT)),
Map.of("c_bigint_symbol", BIGINT_COLUMN))
.orElseThrow();

assertThat(converted.expression()).isEqualTo("-(`c_bigint`)");
assertThat(converted.parameters()).isEqualTo(List.of());
}

@Test
public void testConvertIsNull()
{
// c_varchar IS NULL
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
translateToConnectorExpression(
new IsNullPredicate(
new SymbolReference("c_varchar_symbol")),
Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
Map.of("c_varchar_symbol", VARCHAR_COLUMN))
.orElseThrow();
assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NULL");
assertThat(converted.parameters()).isEqualTo(List.of());
}

@Test
public void testConvertIsNotNull()
{
// c_varchar IS NOT NULL
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
translateToConnectorExpression(
new IsNotNullPredicate(
new SymbolReference("c_varchar_symbol")),
Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
Map.of("c_varchar_symbol", VARCHAR_COLUMN))
.orElseThrow();
assertThat(converted.expression()).isEqualTo("(`c_varchar`) IS NOT NULL");
assertThat(converted.parameters()).isEqualTo(List.of());
}

@Test
public void testConvertNullIf()
{
// nullif(a_varchar, b_varchar)
ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION,
translateToConnectorExpression(
new NullIfExpression(
new SymbolReference("a_varchar_symbol"),
new SymbolReference("b_varchar_symbol")),
ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN.getColumnType(), "b_varchar_symbol", VARCHAR_COLUMN.getColumnType())),
ImmutableMap.of("a_varchar_symbol", VARCHAR_COLUMN, "b_varchar_symbol", VARCHAR_COLUMN))
.orElseThrow();
assertThat(converted.expression()).isEqualTo("NULLIF((`c_varchar`), (`c_varchar`))");
assertThat(converted.parameters()).isEqualTo(List.of());
}

private ConnectorExpression translateToConnectorExpression(Expression expression, Map<String, Type> symbolTypes)
{
return ConnectorExpressionTranslator.translate(
TEST_SESSION,
expression,
TypeProvider.viewOf(symbolTypes.entrySet().stream()
.collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))),
PLANNER_CONTEXT,
createTestingTypeAnalyzer(PLANNER_CONTEXT))
.orElseThrow();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

import com.google.common.collect.ImmutableMap;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
import io.trino.testing.sql.TestTable;
import org.junit.jupiter.api.Test;

import java.util.List;

import static com.google.common.base.Verify.verify;
import static io.trino.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner;
import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -30,9 +36,80 @@ protected QueryRunner createQueryRunner()
return createMySqlQueryRunner(mySqlServer, ImmutableMap.of(), ImmutableMap.of(), REQUIRED_TPCH_TABLES);
}

@Override
protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
return switch (connectorBehavior) {
case SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN -> {
// TODO remove once super has this set to true
verify(!super.hasBehavior(connectorBehavior));
yield true;
}
default -> super.hasBehavior(connectorBehavior);
};
}

@Override
protected void verifyColumnNameLengthFailurePermissible(Throwable e)
{
assertThat(e).hasMessageMatching("(Incorrect column name '.*'|Identifier name '.*' is too long)");
}

@Test
public void testIsNullPredicatePushdown()
{
assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL")).isFullyPushedDown();
assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL OR regionkey = 4")).isFullyPushedDown();

try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_is_null_predicate_pushdown",
"(a_int integer, a_varchar varchar(1))",
List.of(
"1, 'A'",
"2, 'B'",
"1, NULL",
"2, NULL"))) {
assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL OR a_int = 1")).isFullyPushedDown();
}
}

@Test
public void testIsNotNullPredicatePushdown()
{
assertThat(query("SELECT nationkey FROM nation WHERE name IS NOT NULL OR regionkey = 4")).isFullyPushedDown();

try (TestTable table = new TestTable(
getQueryRunner()::execute,
"test_is_not_null_predicate_pushdown",
"(a_int integer, a_varchar varchar(1))",
List.of(
"1, 'A'",
"2, 'B'",
"1, NULL",
"2, NULL"))) {
assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NOT NULL OR a_int = 1")).isFullyPushedDown();
}
}

@Test
public void testNullIfPredicatePushdown()
{
assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'ALGERIA') IS NULL"))
.matches("VALUES BIGINT '0'")
.isFullyPushedDown();

assertThat(query("SELECT name FROM nation WHERE NULLIF(nationkey, 0) IS NULL"))
.matches("VALUES CAST('ALGERIA' AS varchar(255))")
.isFullyPushedDown();

assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Algeria') IS NULL"))
.returnsEmptyResult()
.isFullyPushedDown();

// NULLIF returns the first argument because arguments aren't the same
assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Name not found') = name"))
.matches("SELECT nationkey FROM nation")
.isFullyPushedDown();
}
}
Loading