Skip to content

Commit

Permalink
Add MySQL case sensitive collation varchar push down support
Browse files Browse the repository at this point in the history
When a column uses a case-sensitive collation in MySQL, we can
push down equality/inequality predicates on text columns without
affecting correctness.

Range predicates on text columns are never pushed down and
equality/inequality predicates on text columns are not pushed down if
the column uses a case-insensitive collation.
  • Loading branch information
vlad-lyutenko authored and Praveen2112 committed Aug 1, 2023
1 parent fa4a067 commit 7e9513e
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
Expand All @@ -27,6 +28,7 @@
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.CaseSensitivity;
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
Expand All @@ -38,6 +40,7 @@
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongReadFunction;
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.PredicatePushdownController;
import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteTableName;
Expand All @@ -64,6 +67,9 @@
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
Expand All @@ -85,6 +91,7 @@
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.SQLSyntaxErrorException;
import java.sql.Types;
Expand All @@ -98,6 +105,7 @@
import java.util.function.BiFunction;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -106,27 +114,32 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static com.mysql.cj.exceptions.MysqlErrorNumbers.ER_NO_SUCH_TABLE;
import static com.mysql.cj.exceptions.MysqlErrorNumbers.ER_UNKNOWN_TABLE;
import static com.mysql.cj.exceptions.MysqlErrorNumbers.SQL_STATE_ER_TABLE_EXISTS_ERROR;
import static io.airlift.json.JsonCodec.jsonCodec;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE;
import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold;
import static io.trino.plugin.jdbc.PredicatePushdownController.CASE_INSENSITIVE_CHARACTER_PUSHDOWN;
import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN;
import static io.trino.plugin.jdbc.PredicatePushdownController.FULL_PUSHDOWN;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.dateReadFunctionUsingLocalDate;
import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.defaultCharColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.defaultVarcharColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction;
Expand All @@ -146,13 +159,15 @@
import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharReadFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.CharType.createCharType;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.DecimalType.createDecimalType;
import static io.trino.spi.type.DoubleType.DOUBLE;
Expand All @@ -163,6 +178,8 @@
import static io.trino.spi.type.TimestampType.createTimestampType;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Math.max;
import static java.lang.Math.min;
Expand Down Expand Up @@ -197,6 +214,26 @@ public class MySqlClient
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

// Pushdown controller for character columns: single values, IN lists and NOT IN lists
// are pushed down in full, while anything that ends up as a range is not pushed down.
private static final PredicatePushdownController MYSQL_CHARACTER_PUSHDOWN = (session, domain) -> {
    // A single value (optionally together with NULL) is pushed down untouched
    if (domain.isNullableSingleValue()) {
        return FULL_PUSHDOWN.apply(session, domain);
    }

    Domain compacted = domain.simplify(getDomainCompactionThreshold(session));
    ValueSet compactedValues = compacted.getValues();
    // Either a discrete set (equality / IN) or the complement of one (inequality / NOT IN)
    if (compactedValues.isDiscreteSet() || compactedValues.complement().isDiscreteSet()) {
        return FULL_PUSHDOWN.apply(session, compacted);
    }

    // Domain#simplify can turn a discrete set into a range predicate.
    // Pushing down a range predicate for varchar/char types could lead to incorrect results
    // when the remote database is case insensitive, so the original domain is left to the engine.
    return DISABLE_PUSHDOWN.apply(session, domain);
};

@Inject
public MySqlClient(
BaseJdbcConfig config,
Expand Down Expand Up @@ -233,6 +270,31 @@ public MySqlClient(
.build());
}

/**
 * Resolves the case sensitivity of every column of the given table by preparing a
 * {@code SELECT *} statement and reading {@link ResultSetMetaData#isCaseSensitive(int)}
 * for each column.
 *
 * @return column name to case sensitivity; empty for synthetic table handles
 * @throws TableNotFoundException when the remote reports {@code ER_NO_SUCH_TABLE}
 * @throws TrinoException with {@code JDBC_ERROR} for any other {@link SQLException}
 */
@Override
protected Map<String, CaseSensitivity> getCaseSensitivityForColumns(ConnectorSession session, Connection connection, JdbcTableHandle tableHandle)
{
    if (tableHandle.isSynthetic()) {
        return ImmutableMap.of();
    }

    String selectAllSql = format("SELECT * FROM %s", quoted(tableHandle.asPlainTable().getRemoteTableName()));
    PreparedQuery selectAll = new PreparedQuery(selectAllSql, ImmutableList.of());

    try (PreparedStatement statement = queryBuilder.prepareStatement(this, session, connection, selectAll, Optional.empty())) {
        // Only the statement's metadata is read; the statement is never executed here
        ResultSetMetaData resultMetadata = statement.getMetaData();
        ImmutableMap.Builder<String, CaseSensitivity> sensitivityByColumn = ImmutableMap.builder();
        int columnCount = resultMetadata.getColumnCount();
        // JDBC column indexes are 1-based
        for (int index = 1; index <= columnCount; index++) {
            String columnName = resultMetadata.getColumnName(index);
            CaseSensitivity sensitivity = resultMetadata.isCaseSensitive(index) ? CASE_SENSITIVE : CASE_INSENSITIVE;
            sensitivityByColumn.put(columnName, sensitivity);
        }
        return sensitivityByColumn.buildOrThrow();
    }
    catch (SQLException e) {
        if (e.getErrorCode() != ER_NO_SUCH_TABLE) {
            throw new TrinoException(JDBC_ERROR, "Failed to get case sensitivity for columns. " + firstNonNull(e.getMessage(), e), e);
        }
        throw new TableNotFoundException(tableHandle.asPlainTable().getSchemaTableName());
    }
}

@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
Expand Down Expand Up @@ -426,14 +488,14 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
return Optional.of(decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0))));

case Types.CHAR:
return Optional.of(defaultCharColumnMapping(typeHandle.getRequiredColumnSize(), false));
return Optional.of(mySqlDefaultCharColumnMapping(typeHandle.getRequiredColumnSize(), typeHandle.getCaseSensitivity()));

// TODO not all these type constants are necessarily used by the JDBC driver
case Types.VARCHAR:
case Types.NVARCHAR:
case Types.LONGVARCHAR:
case Types.LONGNVARCHAR:
return Optional.of(defaultVarcharColumnMapping(typeHandle.getRequiredColumnSize(), false));
return Optional.of(mySqlDefaultVarcharColumnMapping(typeHandle.getRequiredColumnSize(), typeHandle.getCaseSensitivity()));

case Types.BINARY:
case Types.VARBINARY:
Expand Down Expand Up @@ -470,6 +532,39 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
return Optional.empty();
}

/**
 * Chooses the Trino varchar type for a remote column of the given size and builds its mapping.
 * Sizes above {@code VarcharType.MAX_LENGTH} map to unbounded varchar.
 */
private static ColumnMapping mySqlDefaultVarcharColumnMapping(int columnSize, Optional<CaseSensitivity> caseSensitivity)
{
    VarcharType mappedType = (columnSize <= VarcharType.MAX_LENGTH)
            ? createVarcharType(columnSize)
            : createUnboundedVarcharType();
    return mySqlVarcharColumnMapping(mappedType, caseSensitivity);
}

/**
 * Builds the slice-based column mapping for a varchar column, selecting the predicate
 * pushdown controller from the column's case sensitivity.
 *
 * @param varcharType Trino varchar type of the column; must not be null
 * @param caseSensitivity case sensitivity of the remote column; when absent the column
 *         is treated as case insensitive
 * @return mapping that uses {@code MYSQL_CHARACTER_PUSHDOWN} for case-sensitive columns
 *         and {@code CASE_INSENSITIVE_CHARACTER_PUSHDOWN} otherwise
 */
private static ColumnMapping mySqlVarcharColumnMapping(VarcharType varcharType, Optional<CaseSensitivity> caseSensitivity)
{
    // Fail fast with a descriptive message, matching the check done in mySqlCharColumnMapping
    requireNonNull(varcharType, "varcharType is null");
    PredicatePushdownController pushdownController = caseSensitivity.orElse(CASE_INSENSITIVE) == CASE_SENSITIVE
            ? MYSQL_CHARACTER_PUSHDOWN
            : CASE_INSENSITIVE_CHARACTER_PUSHDOWN;
    return ColumnMapping.sliceMapping(varcharType, varcharReadFunction(varcharType), varcharWriteFunction(), pushdownController);
}

/**
 * Builds the mapping for a remote CHAR column of the given size. Sizes that do not fit
 * in {@code CharType.MAX_LENGTH} are mapped through the varchar path instead.
 */
private static ColumnMapping mySqlDefaultCharColumnMapping(int columnSize, Optional<CaseSensitivity> caseSensitivity)
{
    boolean fitsInChar = columnSize <= CharType.MAX_LENGTH;
    if (fitsInChar) {
        return mySqlCharColumnMapping(createCharType(columnSize), caseSensitivity);
    }
    return mySqlDefaultVarcharColumnMapping(columnSize, caseSensitivity);
}

/**
 * Builds the slice-based column mapping for a char column. Case-sensitive columns get
 * {@code MYSQL_CHARACTER_PUSHDOWN}; case-insensitive columns, and columns whose case
 * sensitivity is absent, get {@code CASE_INSENSITIVE_CHARACTER_PUSHDOWN}.
 */
private static ColumnMapping mySqlCharColumnMapping(CharType charType, Optional<CaseSensitivity> caseSensitivity)
{
    requireNonNull(charType, "charType is null");

    PredicatePushdownController controller;
    if (caseSensitivity.isPresent() && caseSensitivity.get() == CASE_SENSITIVE) {
        controller = MYSQL_CHARACTER_PUSHDOWN;
    }
    else {
        controller = CASE_INSENSITIVE_CHARACTER_PUSHDOWN;
    }
    return ColumnMapping.sliceMapping(charType, charReadFunction(charType), charWriteFunction(), controller);
}

private LongWriteFunction mySqlDateWriteFunctionUsingLocalDate()
{
return new LongWriteFunction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
*/
package io.trino.plugin.mysql;

import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.TestingConnectorBehavior;
import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TestTable;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.sql.Connection;
Expand All @@ -32,6 +36,8 @@
import static com.google.common.base.Strings.nullToEmpty;
import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static java.lang.String.format;
Expand Down Expand Up @@ -337,6 +343,103 @@ public void testPredicatePushdown()
.isFullyPushedDown();
}

/**
 * Runs the collation pushdown queries against a remote view that exposes {@code name}
 * converted to the given charset and collation.
 */
@Test(dataProvider = "charsetAndCollation")
public void testPredicatePushdownWithCollationView(String charset, String collation)
{
    onRemoteDatabase().execute(format("CREATE OR REPLACE VIEW tpch.test_view AS SELECT regionkey, nationkey, CONVERT(name USING %s) COLLATE %s AS name FROM tpch.nation;", charset, collation));
    try {
        testNationCollationQueries("test_view");
    }
    finally {
        // Drop the view even when an assertion above fails, so a failed run does not leave it behind
        onRemoteDatabase().execute("DROP VIEW tpch.test_view");
    }
}

/**
 * Runs the collation pushdown queries against a remote table created from
 * {@code tpch.nation} with {@code name} converted to the given charset and collation.
 */
@Test(dataProvider = "charsetAndCollation")
public void testPredicatePushdownWithCollation(String charset, String collation)
{
    String tableDefinition = format("AS SELECT regionkey, nationkey, CONVERT(name USING %s) COLLATE %s AS name FROM tpch.nation", charset, collation);
    // The TestTable is closed by try-with-resources once the queries have run
    try (TestTable collatedNation = new TestTable(onRemoteDatabase(), "tpch.nation_collate", tableDefinition)) {
        String tableName = collatedNation.getName();
        testNationCollationQueries(tableName);
    }
}

/**
 * Shared assertions for the collation tests: equality, inequality, IN and NOT IN on
 * {@code name} are expected to be fully pushed down, while range predicates (including
 * those produced by domain compaction) and joins on {@code name} are not.
 *
 * @param objectName name of the table or view to query
 */
private void testNationCollationQueries(String objectName)
{
    String notInQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name NOT IN ('POLAND', 'ROMANIA', 'VIETNAM')", objectName);
    String inQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name IN ('POLAND', 'ROMANIA', 'VIETNAM')", objectName);
    String romaniaAndVietnamRows = "VALUES " +
            "(BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(255))), " +
            "(BIGINT '2', BIGINT '21', CAST('VIETNAM' AS varchar(255)))";

    // varchar inequality
    String inequalityQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name != 'ROMANIA' AND name != 'ALGERIA'", objectName);
    assertThat(query(inequalityQuery)).isFullyPushedDown();

    // varchar equality
    String equalityQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name = 'ROMANIA'", objectName);
    assertThat(query(equalityQuery)).isFullyPushedDown();

    // varchar range
    String rangeQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name BETWEEN 'POLAND' AND 'RPA'", objectName);
    assertThat(query(rangeQuery))
            .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(255)))")
            // We are not supporting range predicate pushdown for varchars
            .isNotFullyPushedDown(FilterNode.class);

    // varchar NOT IN
    assertThat(query(notInQuery)).isFullyPushedDown();

    // Session whose compaction threshold of 1 makes the IN / NOT IN lists below get simplified
    Session smallCompactionThreshold = Session.builder(getSession())
            .setCatalogSessionProperty("mysql", "domain_compaction_threshold", "1")
            .build();

    // varchar NOT IN with small compaction threshold
    assertThat(query(smallCompactionThreshold, notInQuery))
            // no pushdown because it was converted to range predicate
            .isNotFullyPushedDown(
                    node(
                            FilterNode.class,
                            // verify that no constraint is applied by the connector
                            tableScan(
                                    handle -> ((JdbcTableHandle) handle).getConstraint().isAll(),
                                    TupleDomain.all(),
                                    ImmutableMap.of())));

    // varchar IN without domain compaction
    assertThat(query(inQuery))
            .matches(romaniaAndVietnamRows)
            .isFullyPushedDown();

    // varchar IN with small compaction threshold
    assertThat(query(smallCompactionThreshold, inQuery))
            .matches(romaniaAndVietnamRows)
            // no pushdown because it was converted to range predicate
            .isNotFullyPushedDown(
                    node(
                            FilterNode.class,
                            // verify that no constraint is applied by the connector
                            tableScan(
                                    handle -> ((JdbcTableHandle) handle).getConstraint().isAll(),
                                    TupleDomain.all(),
                                    ImmutableMap.of())));

    // varchar different case
    String lowerCaseQuery = format("SELECT regionkey, nationkey, name FROM %s WHERE name = 'romania'", objectName);
    assertThat(query(lowerCaseQuery))
            .returnsEmptyResult()
            .isFullyPushedDown();

    // join on varchar columns
    Session joinPushdownSession = joinPushdownEnabled(getSession());
    String selfJoinQuery = format("SELECT n.name, n2.regionkey FROM %1$s n JOIN %1$s n2 ON n.name = n2.name", objectName);
    assertThat(query(joinPushdownSession, selfJoinQuery))
            .joinIsNotFullyPushedDown();
}

/**
 * Supplies the (charset, collation) pairs consumed by the collation pushdown tests.
 */
@DataProvider
public static Object[][] charsetAndCollation()
{
    Object[] latin1Pair = {"latin1", "latin1_general_cs"};
    Object[] utf8Pair = {"utf8", "utf8_bin"};
    return new Object[][] {latin1Pair, utf8Pair};
}

/**
* This test helps to tune TupleDomain simplification threshold.
*/
Expand Down

0 comments on commit 7e9513e

Please sign in to comment.