diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingConnectorSession.java b/core/trino-main/src/main/java/io/trino/testing/TestingConnectorSession.java index 5af1894eabfa..32855af9e5b3 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingConnectorSession.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingConnectorSession.java @@ -151,10 +151,10 @@ public static Builder builder() public static class Builder { private ConnectorIdentity identity = ConnectorIdentity.ofUser("user"); - private final Optional source = Optional.of("test"); + private Optional source = Optional.of("test"); private TimeZoneKey timeZoneKey = UTC_KEY; private final Locale locale = ENGLISH; - private final Optional traceToken = Optional.empty(); + private Optional traceToken = Optional.empty(); private Optional start = Optional.empty(); private List> propertyMetadatas = ImmutableList.of(); private Map propertyValues = ImmutableMap.of(); @@ -177,6 +177,18 @@ public Builder setStart(Instant start) return this; } + public Builder setSource(String source) + { + this.source = Optional.of(source); + return this; + } + + public Builder setTraceToken(String token) + { + this.traceToken = Optional.of(token); + return this; + } + public Builder setPropertyMetadata(List> propertyMetadatas) { requireNonNull(propertyMetadatas, "propertyMetadatas is null"); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index d19d34291cad..151168859a8a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableSortedSet; import com.google.common.io.Closer; import io.airlift.log.Logger; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; @@ -104,6 +105,7 @@ public abstract class BaseJdbcClient protected final String identifierQuote; protected final Set jdbcTypesMappedToVarchar; private final IdentifierMapping identifierMapping; + private final RemoteQueryModifier queryModifier; private final boolean supportsRetries; @@ -112,7 +114,8 @@ public BaseJdbcClient( String identifierQuote, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier remoteQueryModifier) { this( identifierQuote, @@ -120,6 +123,7 @@ public BaseJdbcClient( queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, + remoteQueryModifier, false); } @@ -129,6 +133,7 @@ public BaseJdbcClient( QueryBuilder queryBuilder, Set jdbcTypesMappedToVarchar, IdentifierMapping identifierMapping, + RemoteQueryModifier remoteQueryModifier, boolean supportsRetries) { this.identifierQuote = requireNonNull(identifierQuote, "identifierQuote is null"); @@ -138,6 +143,7 @@ public BaseJdbcClient( .addAll(requireNonNull(jdbcTypesMappedToVarchar, "jdbcTypesMappedToVarchar is null")) .build(); this.identifierMapping = requireNonNull(identifierMapping, "identifierMapping is null"); + this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null"); this.supportsRetries = supportsRetries; } @@ -610,7 +616,7 @@ protected JdbcOutputTableHandle createTable(ConnectorSession session, ConnectorT RemoteTableName remoteTableName = new RemoteTableName(Optional.ofNullable(catalog), Optional.ofNullable(remoteSchema), remoteTargetTableName); String sql = createTableSql(remoteTableName, columnList.build(), tableMetadata); - execute(connection, sql); + execute(session, connection, sql); return new JdbcOutputTableHandle( catalog, @@ -682,7 +688,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl } String remoteTemporaryTableName = identifierMapping.toRemoteTableName(identity, connection, remoteSchema, generateTemporaryTableName()); - copyTableSchema(connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build()); + copyTableSchema(session, connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build()); Optional pageSinkIdColumn = Optional.empty(); if (shouldUseFaultTolerantExecution(session)) { @@ -709,7 +715,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl } } - protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) + protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) { String sql = format( "CREATE TABLE %s AS SELECT %s FROM %s WHERE 0 = 1", @@ -719,7 +725,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String .collect(joining(", ")), quoted(catalogName, schemaName, tableName)); try { - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -774,7 +780,7 @@ protected void renameTable(ConnectorSession session, String catalogName, String protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException { - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s RENAME TO %s", quoted(catalogName, remoteSchemaName, remoteTableName), quoted(catalogName, newRemoteSchemaName, newRemoteTableName))); @@ -800,9 +806,10 @@ private RemoteTableName constructPageSinkIdsTable(ConnectorSession session, Conn String pageSinkInsertSql = format("INSERT INTO %s (%s) VALUES (?)", quoted(pageSinkTable), pageSinkIdColumnName); + pageSinkInsertSql = queryModifier.apply(session, pageSinkInsertSql); LongWriteFunction pageSinkIdWriter = (LongWriteFunction) toWriteMapping(session, TRINO_PAGE_SINK_ID_COLUMN_TYPE).getWriteFunction(); - execute(connection, pageSinkTableSql); + execute(session, connection, pageSinkTableSql); try (PreparedStatement statement = connection.prepareStatement(pageSinkInsertSql)) { int batchSize = 0; @@ -868,7 +875,7 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha handle.getPageSinkIdColumnName().get()); } - execute(connection, insertSql); + execute(session, connection, insertSql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -919,7 +926,7 @@ private void addColumn(ConnectorSession session, Connection connection, RemoteTa "ALTER TABLE %s ADD %s", quoted(table), getColumnDefinitionSql(session, column, remoteColumnName)); - execute(connection, sql); + execute(session, connection, sql); } @Override @@ -939,7 +946,7 @@ public void renameColumn(ConnectorSession session, JdbcTableHandle handle, JdbcC protected void renameColumn(ConnectorSession session, Connection connection, RemoteTableName remoteTableName, String remoteColumnName, String newRemoteColumnName) throws SQLException { - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s RENAME COLUMN %s TO %s", quoted(remoteTableName), quoted(remoteColumnName), @@ -956,7 +963,7 @@ public void dropColumn(ConnectorSession session, JdbcTableHandle handle, JdbcCol "ALTER TABLE %s DROP COLUMN %s", quoted(handle.asPlainTable().getRemoteTableName()), quoted(remoteColumnName)); - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -1076,7 +1083,7 @@ public void createSchema(ConnectorSession session, String schemaName) protected void createSchema(ConnectorSession session, Connection connection, String remoteSchemaName) throws SQLException { - execute(connection, "CREATE SCHEMA " + quoted(remoteSchemaName)); + execute(session, connection, "CREATE SCHEMA " + quoted(remoteSchemaName)); } @Override @@ -1096,7 +1103,7 @@ public void dropSchema(ConnectorSession session, String schemaName) protected void dropSchema(ConnectorSession session, Connection connection, String remoteSchemaName) throws SQLException { - execute(connection, "DROP SCHEMA " + quoted(remoteSchemaName)); + execute(session, connection, "DROP SCHEMA " + quoted(remoteSchemaName)); } @Override @@ -1118,25 +1125,26 @@ public void renameSchema(ConnectorSession session, String schemaName, String new protected void renameSchema(ConnectorSession session, Connection connection, String remoteSchemaName, String newRemoteSchemaName) throws SQLException { - execute(connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName)); + execute(session, connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName)); } protected void execute(ConnectorSession session, String query) { try (Connection connection = connectionFactory.openConnection(session)) { - execute(connection, query); + execute(session, connection, query); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); } } - protected void execute(Connection connection, String query) + protected void execute(ConnectorSession session, Connection connection, String query) throws SQLException { try (Statement statement = connection.createStatement()) { - log.debug("Execute: %s", query); - statement.execute(query); + String modifiedQuery = queryModifier.apply(session, query); + log.debug("Execute: %s", modifiedQuery); + statement.execute(modifiedQuery); } catch (SQLException e) { e.addSuppressed(new RuntimeException("Query: " + query)); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index 39df67f35483..a32cab3386a5 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -16,8 +16,10 @@ import com.google.common.base.Joiner; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; import io.airlift.log.Logger; import io.airlift.slice.Slice; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.JoinType; @@ -42,6 +44,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; public class DefaultQueryBuilder @@ -53,6 +56,19 @@ public class DefaultQueryBuilder private static final String ALWAYS_TRUE = "1=1"; private static final String ALWAYS_FALSE = "1=0"; + private final RemoteQueryModifier queryModifier; + + public DefaultQueryBuilder() + { + this(RemoteQueryModifier.NONE); + } + + @Inject + public DefaultQueryBuilder(RemoteQueryModifier queryModifier) + { + this.queryModifier = requireNonNull(queryModifier, "queryModifier is null"); + } + @Override public PreparedQuery prepareSelectQuery( JdbcClient client, @@ -163,8 +179,9 @@ public PreparedStatement prepareStatement( PreparedQuery preparedQuery) throws SQLException { - log.debug("Preparing query: %s", preparedQuery.getQuery()); - PreparedStatement statement = client.getPreparedStatement(connection, preparedQuery.getQuery()); + String modifiedQuery = queryModifier.apply(session, preparedQuery.getQuery()); + log.debug("Preparing query: %s", modifiedQuery); + PreparedStatement statement = client.getPreparedStatement(connection, modifiedQuery); List parameters = preparedQuery.getParameters(); for (int i = 0; i < parameters.size(); i++) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java index 42a578380160..131c375e90f0 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java @@ -21,6 +21,7 @@ import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.base.CatalogName; import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.plugin.jdbc.logging.RemoteQueryModifierModule; import io.trino.plugin.jdbc.mapping.IdentifierMappingModule; import io.trino.plugin.jdbc.procedure.FlushJdbcMetadataCacheProcedure; import io.trino.spi.connector.ConnectorAccessControl; @@ -49,6 +50,7 @@ public void setup(Binder binder) { install(new JdbcDiagnosticModule()); install(new IdentifierMappingModule()); + install(new RemoteQueryModifierModule()); newOptionalBinder(binder, ConnectorAccessControl.class); newOptionalBinder(binder, QueryBuilder.class).setDefault().to(DefaultQueryBuilder.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java index 033011a3b5b2..47da5a8658d0 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSink.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; @@ -57,7 +58,7 @@ public class JdbcPageSink private final LongWriteFunction pageSinkIdWriteFunction; private final boolean includePageSinkIdColumn; - public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId) + public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier remoteQueryModifier) { try { connection = jdbcClient.getConnection(session, handle); @@ -112,8 +113,12 @@ public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, Jdbc String insertSql = jdbcClient.buildInsertSql(handle, columnWriters); try { + insertSql = remoteQueryModifier.apply(session, insertSql); statement = connection.prepareStatement(insertSql); } + catch (TrinoException e) { + throw closeAllSuppress(e, connection); + } catch (SQLException e) { closeAllSuppress(e, connection); throw new TrinoException(JDBC_ERROR, e); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSinkProvider.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSinkProvider.java index 7be211ce8f4c..9ac0048cfe6d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSinkProvider.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSinkProvider.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.jdbc; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ConnectorInsertTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; @@ -29,22 +30,24 @@ public class JdbcPageSinkProvider implements ConnectorPageSinkProvider { private final JdbcClient jdbcClient; + private RemoteQueryModifier queryModifier; @Inject - public JdbcPageSinkProvider(JdbcClient jdbcClient) + public JdbcPageSinkProvider(JdbcClient jdbcClient, RemoteQueryModifier remoteQueryModifier) { this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); + this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null"); } @Override public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle tableHandle, ConnectorPageSinkId pageSinkId) { - return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId); + return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier); } @Override public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle tableHandle, ConnectorPageSinkId pageSinkId) { - return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId); + return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifier.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifier.java new file mode 100644 index 000000000000..bfa7af708ccb --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifier.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import com.google.inject.Inject; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; + +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.regex.Pattern; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class FormatBasedRemoteQueryModifier + implements RemoteQueryModifier +{ + private final String commentFormat; + + @Inject + public FormatBasedRemoteQueryModifier(FormatBasedRemoteQueryModifierConfig config) + { + this.commentFormat = requireNonNull(config, "config is null").getFormat(); + checkState(!commentFormat.isBlank(), "comment format is blank"); + } + + @Override + public String apply(ConnectorSession session, String query) + { + String message = commentFormat; + for (PredefinedValue predefinedValue : PredefinedValue.values()) { + message = message.replaceAll(predefinedValue.getMatchCase(), predefinedValue.value(session)); + } + return query + " /*" + message + "*/"; + } + + enum PredefinedValue + { + QUERY_ID(ConnectorSession::getQueryId), + SOURCE(new SanitizedValuesProvider(session -> session.getSource().orElse(""), "$SOURCE")), + USER(ConnectorSession::getUser), + TRACE_TOKEN(new SanitizedValuesProvider(session -> session.getTraceToken().orElse(""), "$TRACE_TOKEN")); + + private final Function valueProvider; + + PredefinedValue(Function valueProvider) + { + this.valueProvider = valueProvider; + } + + String getMatchCase() + { + return "\\$" + this.name(); + } + + String getPredefinedValueCode() + { + return "$" + this.name(); + } + + String value(ConnectorSession session) + { + return valueProvider.apply(session); + } + } + + private static class SanitizedValuesProvider + implements Function + { + private static final Predicate VALIDATION_MATCHER = Pattern.compile("^[\\w_]*$").asMatchPredicate(); + private final Function valueProvider; + private final String name; + + private SanitizedValuesProvider(Function valueProvider, String name) + { + this.valueProvider = requireNonNull(valueProvider, "valueProvider is null"); + this.name = requireNonNull(name, "name is null"); + } + + @Override + public String apply(ConnectorSession session) + { + String value = valueProvider.apply(session); + if (VALIDATION_MATCHER.test(value)) { + return value; + } + throw new TrinoException(JDBC_NON_TRANSIENT_ERROR, format("Passed value %s as %s does not meet security criteria. It can contain only letters, digits and underscores", value, name)); + } + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifierConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifierConfig.java new file mode 100644 index 000000000000..15533018c331 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/FormatBasedRemoteQueryModifierConfig.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.trino.plugin.jdbc.logging.FormatBasedRemoteQueryModifier.PredefinedValue; + +import javax.validation.constraints.AssertTrue; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.MatchResult; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.stream.Collectors.joining; + +public class FormatBasedRemoteQueryModifierConfig +{ + private static final List PREDEFINED_MATCHES = Arrays.stream(PredefinedValue.values()).map(PredefinedValue::getMatchCase).toList(); + private static final Pattern VALIDATION_PATTERN = Pattern.compile("[\\w ,=]|" + String.join("|", PREDEFINED_MATCHES)); + private String format = ""; + + @Config("query.comment-format") + @ConfigDescription("Format in which logs about query execution context should be added as comments sent through jdbc driver.") + public FormatBasedRemoteQueryModifierConfig setFormat(String format) + { + this.format = format; + return this; + } + + public String getFormat() + { + return format; + } + + @AssertTrue(message = "Incorrect format it may consist of only letters, digits, underscores, commas, spaces, equal signs and predefined values") + boolean isFormatValid() + { + Matcher matcher = VALIDATION_PATTERN.matcher(format); + return matcher.results() + .map(MatchResult::group) + .collect(joining()) + .equals(format); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifier.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifier.java new file mode 100644 index 000000000000..22ebc2378582 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifier.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import io.trino.spi.connector.ConnectorSession; + +public interface RemoteQueryModifier +{ + RemoteQueryModifier NONE = (session, query) -> query; + + String apply(ConnectorSession session, String query); +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifierModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifierModule.java new file mode 100644 index 000000000000..bed2312497b8 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/logging/RemoteQueryModifierModule.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import com.google.inject.Binder; +import io.airlift.configuration.AbstractConfigurationAwareModule; + +import static com.google.inject.Scopes.SINGLETON; +import static io.airlift.configuration.ConditionalModule.conditionalModule; +import static io.airlift.configuration.ConfigBinder.configBinder; + +public class RemoteQueryModifierModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(FormatBasedRemoteQueryModifierConfig.class); + install(conditionalModule( + FormatBasedRemoteQueryModifierConfig.class, + config -> config.getFormat().isBlank(), + innerBinder -> innerBinder.bind(RemoteQueryModifier.class).toInstance(RemoteQueryModifier.NONE), + innerBinder -> innerBinder.bind(RemoteQueryModifier.class).to(FormatBasedRemoteQueryModifier.class).in(SINGLETON))); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java index 3a7139d5665c..21d380659abe 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestingH2JdbcClient.java @@ -19,6 +19,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementCountAll; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteVariable; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; @@ -89,7 +90,7 @@ public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFa public TestingH2JdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping) { - super(config, "\"", connectionFactory, new DefaultQueryBuilder(), identifierMapping); + super(config, "\"", connectionFactory, new DefaultQueryBuilder(), identifierMapping, RemoteQueryModifier.NONE); } @Override diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java new file mode 100644 index 000000000000..08db3fa328bd --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifier.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import io.trino.spi.TrinoException; +import io.trino.spi.security.ConnectorIdentity; +import io.trino.testing.TestingConnectorSession; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestFormatBasedRemoteQueryModifier +{ + @Test + public void testCreatingCommentToAppendBasedOnFormatAndConnectorSession() + { + TestingConnectorSession connectorSession = TestingConnectorSession.builder() + .setTraceToken("trace_token") + .setSource("source") + .setIdentity(ConnectorIdentity.ofUser("Alice")) + .build(); + + FormatBasedRemoteQueryModifier modifier = createRemoteQueryModifier("Query=$QUERY_ID Execution for user=$USER with source=$SOURCE ttoken=$TRACE_TOKEN"); + String modifiedQuery = modifier.apply(connectorSession, "SELECT * from USERS"); + + assertThat(modifiedQuery) + .isEqualTo("SELECT * from USERS /*Query=%s Execution for user=%s with source=%s ttoken=%s*/", connectorSession.getQueryId(), "Alice", "source", "trace_token"); + } + + @Test + public void testCreatingCommentWithDuplicatedPredefinedValues() + { + TestingConnectorSession connectorSession = TestingConnectorSession.builder() + .setTraceToken("trace_token") + .setSource("source") + .setIdentity(ConnectorIdentity.ofUser("Alice")) + .build(); + + FormatBasedRemoteQueryModifier modifier = createRemoteQueryModifier("$QUERY_ID, $QUERY_ID, $QUERY_ID, $QUERY_ID, $USER, $USER, $SOURCE, $SOURCE, $SOURCE, $TRACE_TOKEN, $TRACE_TOKEN"); + String modifiedQuery = modifier.apply(connectorSession, "SELECT * from USERS"); + + assertThat(modifiedQuery) + .isEqualTo("SELECT * from USERS /*%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s*/", + connectorSession.getQueryId(), + connectorSession.getQueryId(), + connectorSession.getQueryId(), + connectorSession.getQueryId(), + "Alice", + "Alice", + "source", + "source", + "source", + "trace_token", + "trace_token"); + } + + @Test + public void testForSQLInjectionsByTraceToken() + { + TestingConnectorSession connectorSession = TestingConnectorSession.builder() + .setTraceToken("*/; DROP TABLE TABLE_A; /*") + .setSource("source") + .setIdentity(ConnectorIdentity.ofUser("Alice")) + .build(); + + FormatBasedRemoteQueryModifier modifier = createRemoteQueryModifier("Query=$QUERY_ID Execution for user=$USER with source=$SOURCE ttoken=$TRACE_TOKEN"); + + assertThatThrownBy(() -> modifier.apply(connectorSession, "SELECT * from USERS")) + .isInstanceOf(TrinoException.class) + .hasMessage("Passed value */; DROP TABLE TABLE_A; /* as $TRACE_TOKEN does not meet security criteria. It can contain only letters, digits and underscores"); + } + + @Test + public void testForSQLInjectionsBySource() + { + TestingConnectorSession connectorSession = TestingConnectorSession.builder() + .setTraceToken("trace_token") + .setSource("*/; DROP TABLE TABLE_A; /*") + .setIdentity(ConnectorIdentity.ofUser("Alice")) + .build(); + + FormatBasedRemoteQueryModifier modifier = createRemoteQueryModifier("Query=$QUERY_ID Execution for user=$USER with source=$SOURCE ttoken=$TRACE_TOKEN"); + + assertThatThrownBy(() -> modifier.apply(connectorSession, "SELECT * from USERS")) + .isInstanceOf(TrinoException.class) + .hasMessage("Passed value */; DROP TABLE TABLE_A; /* as $SOURCE does not meet security criteria. It can contain only letters, digits and underscores"); + } + + @Test + public void testFormatWithEmptyValues() + { + TestingConnectorSession connectorSession = TestingConnectorSession.builder() + .setIdentity(ConnectorIdentity.ofUser("Alice")) + .setSource("") + .build(); + + FormatBasedRemoteQueryModifier modifier = createRemoteQueryModifier("source=$SOURCE ttoken=$TRACE_TOKEN"); + + String modifiedQuery = modifier.apply(connectorSession, "SELECT * FROM USERS"); + + assertThat(modifiedQuery) + .isEqualTo("SELECT * FROM USERS /*source= ttoken=*/"); + } + + private static FormatBasedRemoteQueryModifier createRemoteQueryModifier(String commentFormat) + { + return new FormatBasedRemoteQueryModifier(new FormatBasedRemoteQueryModifierConfig().setFormat(commentFormat)); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java new file mode 100644 index 000000000000..a261fd92e39e --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierConfig.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.util.Arrays.array; + +public class TestFormatBasedRemoteQueryModifierConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(FormatBasedRemoteQueryModifierConfig.class).setFormat("")); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder().put("query.comment-format", "format").buildOrThrow(); + + FormatBasedRemoteQueryModifierConfig expected = new FormatBasedRemoteQueryModifierConfig().setFormat("format"); + + assertFullMapping(properties, expected); + } + + @Test(dataProvider = "getForbiddenValuesInFormat") + public void testInvalidFormatValue(String incorrectValue) + { + assertThat(new FormatBasedRemoteQueryModifierConfig().setFormat(incorrectValue).isFormatValid()) + .isFalse(); + } + + @DataProvider + public static Object[][] getForbiddenValuesInFormat() + { + return array( + array("*"), + array("("), + array(")"), + array("["), + array("]"), + array("{"), + array("}"), + array("&"), + array("@"), + array("!"), + array("#"), + array("%"), + array("^"), + array("$"), + array("\\"), + array("/"), + array("?"), + array(">"), + array("<"), + array(";"), + array("\""), + array(":"), + array("|")); + } + + @Test + public void testValidFormatWithPredefinedValues() + { + assertThat(new FormatBasedRemoteQueryModifierConfig().setFormat("$QUERY_ID $USER $SOURCE $TRACE_TOKEN").isFormatValid()).isTrue(); + } + + @Test + public void testValidFormatWithDuplicatedPredefinedValues() + { + assertThat(new FormatBasedRemoteQueryModifierConfig().setFormat("$QUERY_ID $QUERY_ID $USER $USER $SOURCE $SOURCE $TRACE_TOKEN $TRACE_TOKEN").isFormatValid()).isTrue(); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierModule.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierModule.java new file mode 100644 index 000000000000..b1035b7ae9ce --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/logging/TestFormatBasedRemoteQueryModifierModule.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.logging; + +import com.google.common.collect.ImmutableMap; +import io.airlift.bootstrap.Bootstrap; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TestFormatBasedRemoteQueryModifierModule +{ + @Test + public void testRemoteQueryModifierAvailableByDefault() + { + RemoteQueryModifier remoteQueryModifier = new Bootstrap(new RemoteQueryModifierModule()) + .initialize() + .getInstance(RemoteQueryModifier.class); + + assertThat(remoteQueryModifier) + .isEqualTo(RemoteQueryModifier.NONE); + } + + @Test + public void testNonEmptyFormatProducingNonDefaultRemoteQueryModifier() + { + RemoteQueryModifier remoteQueryModifier = new Bootstrap(new RemoteQueryModifierModule()) + .setRequiredConfigurationProperties(ImmutableMap.of("query.comment-format", "valid format")) + .initialize() + .getInstance(RemoteQueryModifier.class); + + assertThat(remoteQueryModifier) + .isNotEqualTo(RemoteQueryModifier.NONE); + } +} diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index 9952a4716c82..65bf0579df22 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -48,6 +48,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -204,9 +205,10 @@ public ClickHouseClient( ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { - super(config, "\"", connectionFactory, queryBuilder, identifierMapping); + super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID)); this.ipAddressType = typeManager.getType(new TypeSignature(StandardTypes.IPADDRESS)); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); @@ -256,7 +258,7 @@ protected String quoted(@Nullable String catalog, @Nullable String schema, Strin } @Override - protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) + protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) { // ClickHouse does not support `create table tbl as select * from tbl2 where 0=1` // ClickHouse supports the following two methods to copy schema @@ -267,7 +269,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String quoted(null, schemaName, newTableName), quoted(null, schemaName, tableName)); try { - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -368,7 +370,7 @@ public void setTableProperties(ConnectorSession session, JdbcTableHandle handle, "ALTER TABLE %s MODIFY %s", quoted(handle.asPlainTable().getRemoteTableName()), join(" ", tableOptions.build())); - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -399,21 +401,21 @@ protected String getColumnDefinitionSql(ConnectorSession session, ColumnMetadata protected void createSchema(ConnectorSession session, Connection connection, String remoteSchemaName) throws SQLException { - execute(connection, "CREATE DATABASE " + quoted(remoteSchemaName)); + execute(session, connection, "CREATE DATABASE " + quoted(remoteSchemaName)); } @Override protected void dropSchema(ConnectorSession session, Connection connection, String remoteSchemaName) throws SQLException { - execute(connection, "DROP DATABASE " + quoted(remoteSchemaName)); + execute(session, connection, "DROP DATABASE " + quoted(remoteSchemaName)); } @Override protected void renameSchema(ConnectorSession session, Connection connection, String remoteSchemaName, String newRemoteSchemaName) throws SQLException { - execute(connection, "RENAME DATABASE " + quoted(remoteSchemaName) + " TO " + quoted(newRemoteSchemaName)); + execute(session, connection, "RENAME DATABASE " + quoted(remoteSchemaName) + " TO " + quoted(newRemoteSchemaName)); } @Override @@ -425,7 +427,7 @@ public void addColumn(ConnectorSession session, JdbcTableHandle handle, ColumnMe "ALTER TABLE %s ADD COLUMN %s", quoted(handle.asPlainTable().getRemoteTableName()), getColumnDefinitionSql(session, column, remoteColumnName)); - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -469,7 +471,7 @@ protected Optional> getTableTypes() protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException { - execute(connection, format("RENAME TABLE %s.%s TO %s.%s", + execute(session, connection, format("RENAME TABLE %s.%s TO %s.%s", quoted(remoteSchemaName), quoted(remoteTableName), quoted(newRemoteSchemaName), diff --git a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java index a9fd45c9cfc3..552a6b16eeb2 100644 --- a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java +++ b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java @@ -30,6 +30,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnMetadata; @@ -150,9 +151,9 @@ public class DruidJdbcClient .withChronology(IsoChronology.INSTANCE); @Inject - public DruidJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public DruidJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { - super(config, "\"", connectionFactory, queryBuilder, identifierMapping); + super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); } @Override diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java index 8fb98c5df11e..4e34ff0d2bce 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java @@ -43,6 +43,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -157,9 +158,9 @@ public class MariaDbClient private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject - public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { - super(config, "`", connectionFactory, queryBuilder, identifierMapping); + super(config, "`", connectionFactory, queryBuilder, identifierMapping, queryModifier); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() @@ -454,7 +455,7 @@ protected void renameColumn(ConnectorSession session, Connection connection, Rem quoted(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()), quoted(remoteColumnName), quoted(newRemoteColumnName)); - execute(connection, sql); + execute(session, connection, sql); } catch (SQLSyntaxErrorException syntaxError) { // Note: SQLSyntaxErrorException can be thrown also when column name is invalid @@ -466,7 +467,7 @@ protected void renameColumn(ConnectorSession session, Connection connection, Rem } @Override - protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) + protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) { // Copy all columns for enforcing NOT NULL option in the temp table String tableCopyFormat = "CREATE TABLE %s AS SELECT * FROM %s WHERE 0 = 1"; @@ -475,7 +476,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String quoted(catalogName, schemaName, newTableName), quoted(catalogName, schemaName, tableName)); try { - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java index 46965eaaffe7..858ac71d8e42 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java @@ -20,6 +20,7 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -62,7 +63,8 @@ public class TestMariaDbClient throw new UnsupportedOperationException(); }, new DefaultQueryBuilder(), - new DefaultIdentifierMapping()); + new DefaultIdentifierMapping(), + RemoteQueryModifier.NONE); @Test public void testImplementCount() diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index dc1c32e80bb7..c72baf32a4a6 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -49,6 +49,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -204,9 +205,10 @@ public MySqlClient( ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { - super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, true); + super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, true); this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); this.statisticsEnabled = statisticsConfig.isEnabled(); @@ -589,7 +591,7 @@ protected void renameColumn(ConnectorSession session, Connection connection, Rem quoted(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()), quoted(remoteColumnName), quoted(newRemoteColumnName)); - execute(connection, sql); + execute(session, connection, sql); } @Override @@ -599,7 +601,7 @@ public void renameSchema(ConnectorSession session, String schemaName, String new } @Override - protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) + protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) { String tableCopyFormat = "CREATE TABLE %s AS SELECT * FROM %s WHERE 0 = 1"; if (isGtidMode(connection)) { @@ -610,7 +612,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String quoted(catalogName, schemaName, newTableName), quoted(catalogName, schemaName, tableName)); try { - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java index f271bbf9fdc2..820655e5a5ba 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java @@ -21,6 +21,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -66,7 +67,8 @@ public class TestMySqlClient }, new DefaultQueryBuilder(), TESTING_TYPE_MANAGER, - new DefaultIdentifierMapping()); + new DefaultIdentifierMapping(), + RemoteQueryModifier.NONE); @Test public void testImplementCount() diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index a5c52ad7dcc4..48c0ca09470b 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -49,6 +49,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -206,9 +207,10 @@ public OracleClient( OracleConfig oracleConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { - super(config, "\"", connectionFactory, queryBuilder, identifierMapping); + super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); this.synonymsEnabled = oracleConfig.isSynonymsEnabled(); @@ -279,7 +281,7 @@ protected void renameTable(ConnectorSession session, Connection connection, Stri } String newTableName = newRemoteTableName.toUpperCase(ENGLISH); - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s RENAME TO %s", quoted(catalogName, remoteSchemaName, remoteTableName), quoted(newTableName))); diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleClient.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleClient.java index cc547c7d5f0e..944943185a89 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleClient.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/TestOracleClient.java @@ -18,6 +18,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; @@ -58,7 +59,8 @@ public class TestOracleClient throw new UnsupportedOperationException(); }, new DefaultQueryBuilder(), - new DefaultIdentifierMapping()); + new DefaultIdentifierMapping(), + RemoteQueryModifier.NONE); private static final ConnectorSession SESSION = TestingConnectorSession.SESSION; diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java index 1ede65dbc939..92f828561e7b 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java @@ -35,6 +35,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; @@ -212,7 +213,7 @@ public class PhoenixClient private final Configuration configuration; @Inject - public PhoenixClient(PhoenixConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public PhoenixClient(PhoenixConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) throws SQLException { super( @@ -221,6 +222,7 @@ public PhoenixClient(PhoenixConfig config, ConnectionFactory connectionFactory, queryBuilder, ImmutableSet.of(), identifierMapping, + queryModifier, false); this.configuration = newEmptyConfiguration(); getConnectionProperties(config).forEach((k, v) -> configuration.set((String) k, (String) v)); diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java index 71e0148adfe6..d660152b00ac 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClientModule.java @@ -53,6 +53,7 @@ import io.trino.plugin.jdbc.TypeHandlingJdbcConfig; import io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties; import io.trino.plugin.jdbc.credential.EmptyCredentialProvider; +import io.trino.plugin.jdbc.logging.RemoteQueryModifierModule; import io.trino.plugin.jdbc.mapping.IdentifierMappingModule; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorMetadata; @@ -97,6 +98,7 @@ public PhoenixClientModule(String catalogName) @Override protected void setup(Binder binder) { + install(new RemoteQueryModifierModule()); binder.bind(ConnectorSplitManager.class).annotatedWith(ForJdbcDynamicFiltering.class).to(PhoenixSplitManager.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).annotatedWith(ForClassLoaderSafe.class).to(JdbcDynamicFilteringSplitManager.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).to(ClassLoaderSafeConnectorSplitManager.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java index 118d02dbb8a2..06f8d666bde2 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.postgresql; +import com.google.inject.Inject; import io.trino.plugin.jdbc.DefaultQueryBuilder; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; @@ -20,6 +21,7 @@ import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.QueryParameter; import io.trino.plugin.jdbc.WriteFunction; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; @@ -34,6 +36,12 @@ public class CollationAwareQueryBuilder extends DefaultQueryBuilder { + @Inject + public CollationAwareQueryBuilder(RemoteQueryModifier queryModifier) + { + super(queryModifier); + } + @Override protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 44759a24fd8f..e6c1e2c2e93d 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -67,6 +67,7 @@ import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; import io.trino.spi.TrinoException; @@ -282,9 +283,10 @@ public PostgreSqlClient( ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { - super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, true); + super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, true); this.jsonType = typeManager.getType(new TypeSignature(JSON)); this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID)); this.varcharMapType = (MapType) typeManager.getType(mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())); @@ -369,7 +371,7 @@ protected void renameTable(ConnectorSession session, Connection connection, Stri throw new TrinoException(NOT_SUPPORTED, "This connector does not support renaming tables across schemas"); } - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s RENAME TO %s", quoted(catalogName, remoteSchemaName, remoteTableName), quoted(newRemoteTableName))); diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index f25bff7155ce..fe7f5c0f1613 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.JdbcMetadataSessionProperties; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -112,7 +113,8 @@ public class TestPostgreSqlClient session -> { throw new UnsupportedOperationException(); }, new DefaultQueryBuilder(), TESTING_TYPE_MANAGER, - new DefaultIdentifierMapping()); + new DefaultIdentifierMapping(), + RemoteQueryModifier.NONE); private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java new file mode 100644 index 000000000000..a78a3cad2517 --- /dev/null +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.postgresql; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static io.trino.tpch.TpchTable.NATION; +import static org.assertj.core.api.Assertions.assertThat; + +@Test(singleThreaded = true) +public class TestRemoteQueryCommentLogging + extends AbstractTestQueryFramework +{ + private TestingPostgreSqlServer postgreSqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + postgreSqlServer = closeAfterClass(new TestingPostgreSqlServer()); + DistributedQueryRunner distributedQueryRunner = createPostgreSqlQueryRunner( + postgreSqlServer, + ImmutableMap.of(), + ImmutableMap.of("query.comment-format", "query executed by $USER"), + ImmutableList.of(CUSTOMER, NATION)); + + return distributedQueryRunner; + } + + @Test + public void testShouldLogContextInComment() + { + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("CREATE TABLE postgresql.tpch.log_nation_test_table AS (SELECT * FROM postgresql.tpch.nation)")) + .stopEventsRecording() + .streamQueriesContaining("\"tpch\".\"tpch\".\"tmp_trino_")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isGreaterThanOrEqualTo(3); //Depending on whether fault tolerancy is enabled or not, this might vary and we don't want to over-specify + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("SELECT * FROM postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("log_nation_test_table")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isEqualTo(1); + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("DELETE FROM postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("log_nation_test_table")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isEqualTo(1); + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("INSERT INTO postgresql.tpch.log_nation_test_table VALUES (1, 'nation', 1, 'nation')")) + .stopEventsRecording() + .streamQueriesContaining("log_nation_test_table", "\"tpch\".\"tpch\".\"tmp_trino_")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isGreaterThanOrEqualTo(1); //Depending on whether fault tolerancy is enabled or not, this might vary and we don't want to over-specify + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("DROP TABLE postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("log_nation_test_table")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isEqualTo(1); + } + + @Test + public void testShouldLogContextInCommentForTableFunctionsQueryPassthrough() + { + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("SELECT * FROM TABLE( postgresql.system.query(query => 'SELECT name FROM tpch.nation WHERE nationkey = 0'))")) + .stopEventsRecording() + .streamQueriesContaining("tpch.nation")) + .allMatch(query -> query.contains("SELECT name FROM tpch.nation WHERE nationkey = 0")) + .allMatch(query -> query.endsWith("/*query executed by user*/")) + .size() + .isEqualTo(1); + } +} diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLoggingDisabledByDefault.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLoggingDisabledByDefault.java new file mode 100644 index 000000000000..0fb0ec7a5eb5 --- /dev/null +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLoggingDisabledByDefault.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.postgresql; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static io.trino.tpch.TpchTable.NATION; +import static org.assertj.core.api.Assertions.assertThat; + +@Test(singleThreaded = true) +public class TestRemoteQueryCommentLoggingDisabledByDefault + extends AbstractTestQueryFramework +{ + private TestingPostgreSqlServer postgreSqlServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + postgreSqlServer = closeAfterClass(new TestingPostgreSqlServer()); + DistributedQueryRunner distributedQueryRunner = createPostgreSqlQueryRunner( + postgreSqlServer, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableList.of(CUSTOMER, NATION)); + + return distributedQueryRunner; + } + + @Test + public void testShouldNotLogContextInComments() + { + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("CREATE TABLE postgresql.tpch.log_nation_test_table AS (SELECT * FROM postgresql.tpch.nation)")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isEqualTo(0); + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("SELECT * FROM postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isEqualTo(0); + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("DELETE FROM postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isEqualTo(0); + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("INSERT INTO postgresql.tpch.log_nation_test_table VALUES (1, 'nation', 1, 'nation')")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isGreaterThanOrEqualTo(0); //Depending on whether fault tolerancy is enabled or not, this might vary and we don't want to over-specify + + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("DROP TABLE postgresql.tpch.log_nation_test_table")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isEqualTo(0); + } + + @Test + public void testShouldNotLogContextInCommentForTableFunctionsQueryPassthrough() + { + assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("SELECT * FROM TABLE( postgresql.system.query(query => 'SELECT name FROM tpch.nation WHERE nationkey = 0'))")) + .stopEventsRecording() + .streamQueriesContaining("*/")) + .size() + .isEqualTo(0); + } +} diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java index 5096f93c3a7b..28ab9168a7dc 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestingPostgreSqlServer.java @@ -26,6 +26,9 @@ import java.util.Iterator; import java.util.List; import java.util.Properties; +import java.util.function.Supplier; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -33,6 +36,7 @@ import static io.trino.plugin.jdbc.RemoteDatabaseEvent.Status.RUNNING; import static io.trino.testing.containers.TestContainers.exposeFixedPorts; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.function.Predicate.not; import static org.testcontainers.containers.PostgreSQLContainer.POSTGRESQL_PORT; @@ -44,8 +48,10 @@ public class TestingPostgreSqlServer private static final String DATABASE = "tpch"; private static final String LOG_PREFIX_REGEXP = "^([-:0-9. ]+UTC \\[[0-9]+\\] )"; - private static final String LOG_RUNNING_STATEMENT_PREFIX = "LOG: execute : "; + private static final String LOG_RUNNING_STATEMENT_PREFIX = "LOG: execute "; private static final String LOG_CANCELLATION_EVENT = "ERROR: canceling statement due to user request"; + + private static final Pattern SQL_QUERY_FIND_PATTERN = Pattern.compile("^(: |/C_\\d: )(.*)"); //In PgSQL cursor queries and non-cursor queries are logged differently private static final String LOG_CANCELLED_STATEMENT_PREFIX = "STATEMENT: "; private final PostgreSQLContainer dockerContainer; @@ -88,6 +94,13 @@ private static void execute(String url, Properties properties, String sql) } } + DatabaseEventsRecorder recordEventsForOperations(Runnable operation) + { + DatabaseEventsRecorder events = DatabaseEventsRecorder.startRecording(this); + operation.run(); + return events; + } + protected List getRemoteDatabaseEvents() { List logs = getLogs(); @@ -96,7 +109,11 @@ protected List getRemoteDatabaseEvents() while (logsIterator.hasNext()) { String logLine = logsIterator.next().replaceAll(LOG_PREFIX_REGEXP, ""); if (logLine.startsWith(LOG_RUNNING_STATEMENT_PREFIX)) { - events.add(new RemoteDatabaseEvent(logLine.substring(LOG_RUNNING_STATEMENT_PREFIX.length()), RUNNING)); + Matcher matcher = SQL_QUERY_FIND_PATTERN.matcher(logLine.substring(LOG_RUNNING_STATEMENT_PREFIX.length())); + if (matcher.find()) { + String sqlStatement = matcher.group(2); + events.add(new RemoteDatabaseEvent(sqlStatement, RUNNING)); + } } if (logLine.equals(LOG_CANCELLATION_EVENT)) { // next line must be present @@ -146,4 +163,39 @@ public void close() { dockerContainer.close(); } + + public static class DatabaseEventsRecorder + { + private final Supplier> loggedQueriesSource; + + private DatabaseEventsRecorder(Supplier> loggedQueriesSource) + { + this.loggedQueriesSource = requireNonNull(loggedQueriesSource, "loggedQueriesSource is null"); + } + + static DatabaseEventsRecorder startRecording(TestingPostgreSqlServer server) + { + int startingEventsCount = server.getRemoteDatabaseEvents().size(); + return new DatabaseEventsRecorder(() -> + server.getRemoteDatabaseEvents().stream() + .skip(startingEventsCount) + .map(RemoteDatabaseEvent::getQuery)); + } + + public DatabaseEventsRecorder stopEventsRecording() + { + List queries = loggedQueriesSource.get().collect(toImmutableList()); + return new DatabaseEventsRecorder(queries::stream); + } + + public Stream streamQueriesContaining(String queryPart, String... alternativeQueryParts) + { + ImmutableList queryParts = ImmutableList.builder() + .add(queryPart) + .addAll(ImmutableList.copyOf(alternativeQueryParts)) + .build(); + return loggedQueriesSource.get() + .filter(query -> queryParts.stream().anyMatch(query::contains)); + } + } } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index e5e361753f85..6496c0b678ad 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; @@ -90,9 +91,9 @@ public class RedshiftClient extends BaseJdbcClient { @Inject - public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping) + public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { - super(config, "\"", connectionFactory, queryBuilder, identifierMapping); + super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); } @Override @@ -110,7 +111,7 @@ protected void renameTable(ConnectorSession session, Connection connection, Stri throw new TrinoException(NOT_SUPPORTED, "This connector does not support renaming tables across schemas"); } - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s RENAME TO %s", quoted(catalogName, remoteSchemaName, remoteTableName), quoted(newRemoteTableName))); diff --git a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java index bde1fc93d0fe..b158d67cc25b 100644 --- a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java +++ b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java @@ -29,6 +29,7 @@ import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -152,9 +153,9 @@ public class SingleStoreClient private final Type jsonType; @Inject - public SingleStoreClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, IdentifierMapping identifierMapping) + public SingleStoreClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, TypeManager typeManager, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { - super(config, "`", connectionFactory, queryBuilder, identifierMapping); + super(config, "`", connectionFactory, queryBuilder, identifierMapping, queryModifier); requireNonNull(typeManager, "typeManager is null"); this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); } @@ -340,7 +341,7 @@ protected void renameColumn(ConnectorSession session, Connection connection, Rem throws SQLException { // SingleStore versions earlier than 5.7 do not support the CHANGE syntax - execute(connection, format( + execute(session, connection, format( "ALTER TABLE %s CHANGE %s %s", quoted(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()), quoted(remoteColumnName), diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 688104c9369f..9647c8436afc 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -53,6 +53,7 @@ import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -212,9 +213,10 @@ public SqlServerClient( JdbcStatisticsConfig statisticsConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, - IdentifierMapping identifierMapping) + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { - super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, true); + super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, true); this.statisticsEnabled = statisticsConfig.isEnabled(); @@ -284,7 +286,7 @@ protected void enableTableLockOnBulkLoadTableOption(ConnectorSession session, Jd // note: this is not a request to lock a table immediately String sql = format("EXEC sp_tableoption '%s', 'table lock on bulk load', '1'", quoted(table.getCatalogName(), table.getSchemaName(), table.getTemporaryTableName().orElseGet(table::getTableName))); - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); @@ -375,7 +377,7 @@ public void renameSchema(ConnectorSession session, String schemaName, String new } @Override - protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) + protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List columnNames) { String sql = format( "SELECT %s INTO %s FROM %s WHERE 0 = 1", @@ -385,7 +387,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String quoted(catalogName, schemaName, newTableName), quoted(catalogName, schemaName, tableName)); try { - execute(connection, sql); + execute(session, connection, sql); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java index be18dc5faeff..c1d2846f291f 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java @@ -21,6 +21,7 @@ import io.trino.plugin.jdbc.JdbcExpression; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; @@ -64,7 +65,8 @@ public class TestSqlServerClient throw new UnsupportedOperationException(); }, new DefaultQueryBuilder(), - new DefaultIdentifierMapping()); + new DefaultIdentifierMapping(), + RemoteQueryModifier.NONE); @Test public void testImplementCount()