Skip to content

Commit

Permalink
Add external query trino logs to sql based connectors
Browse files Browse the repository at this point in the history
This change allows for sending additional logs, that are
send in comments to external systems (jdbc only atm). It allows
to pass additional security and observability information to these
systems, that wouldn't be possible to send otherwie.

This implementation covers SELECT, DDL and DML queries, although it
does not cover any metadata queries - as these are made by the drivers
itself and we can't easily intercept these.
  • Loading branch information
s2lomon committed Nov 23, 2022
1 parent 10d52dc commit e2771a8
Show file tree
Hide file tree
Showing 34 changed files with 880 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ public static Builder builder()
public static class Builder
{
private ConnectorIdentity identity = ConnectorIdentity.ofUser("user");
private final Optional<String> source = Optional.of("test");
private Optional<String> source = Optional.of("test");
private TimeZoneKey timeZoneKey = UTC_KEY;
private final Locale locale = ENGLISH;
private final Optional<String> traceToken = Optional.empty();
private Optional<String> traceToken = Optional.empty();
private Optional<Instant> start = Optional.empty();
private List<PropertyMetadata<?>> propertyMetadatas = ImmutableList.of();
private Map<String, Object> propertyValues = ImmutableMap.of();
Expand All @@ -177,6 +177,18 @@ public Builder setStart(Instant start)
return this;
}

public Builder setSource(String source)
{
this.source = Optional.of(source);
return this;
}

public Builder setTraceToken(String token)
{
this.traceToken = Optional.of(token);
return this;
}

public Builder setPropertyMetadata(List<PropertyMetadata<?>> propertyMetadatas)
{
requireNonNull(propertyMetadatas, "propertyMetadatas is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.io.Closer;
import io.airlift.log.Logger;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
Expand Down Expand Up @@ -104,6 +105,7 @@ public abstract class BaseJdbcClient
protected final String identifierQuote;
protected final Set<String> jdbcTypesMappedToVarchar;
private final IdentifierMapping identifierMapping;
private final RemoteQueryModifier queryModifier;

private final boolean supportsRetries;

Expand All @@ -112,14 +114,16 @@ public BaseJdbcClient(
String identifierQuote,
ConnectionFactory connectionFactory,
QueryBuilder queryBuilder,
IdentifierMapping identifierMapping)
IdentifierMapping identifierMapping,
RemoteQueryModifier remoteQueryModifier)
{
this(
identifierQuote,
connectionFactory,
queryBuilder,
config.getJdbcTypesMappedToVarchar(),
identifierMapping,
remoteQueryModifier,
false);
}

Expand All @@ -129,6 +133,7 @@ public BaseJdbcClient(
QueryBuilder queryBuilder,
Set<String> jdbcTypesMappedToVarchar,
IdentifierMapping identifierMapping,
RemoteQueryModifier remoteQueryModifier,
boolean supportsRetries)
{
this.identifierQuote = requireNonNull(identifierQuote, "identifierQuote is null");
Expand All @@ -138,6 +143,7 @@ public BaseJdbcClient(
.addAll(requireNonNull(jdbcTypesMappedToVarchar, "jdbcTypesMappedToVarchar is null"))
.build();
this.identifierMapping = requireNonNull(identifierMapping, "identifierMapping is null");
this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null");
this.supportsRetries = supportsRetries;
}

Expand Down Expand Up @@ -610,7 +616,7 @@ protected JdbcOutputTableHandle createTable(ConnectorSession session, ConnectorT

RemoteTableName remoteTableName = new RemoteTableName(Optional.ofNullable(catalog), Optional.ofNullable(remoteSchema), remoteTargetTableName);
String sql = createTableSql(remoteTableName, columnList.build(), tableMetadata);
execute(connection, sql);
execute(session, connection, sql);

return new JdbcOutputTableHandle(
catalog,
Expand Down Expand Up @@ -682,7 +688,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl
}

String remoteTemporaryTableName = identifierMapping.toRemoteTableName(identity, connection, remoteSchema, generateTemporaryTableName());
copyTableSchema(connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build());
copyTableSchema(session, connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build());

Optional<ColumnMetadata> pageSinkIdColumn = Optional.empty();
if (shouldUseFaultTolerantExecution(session)) {
Expand All @@ -709,7 +715,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl
}
}

protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames)
protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames)
{
String sql = format(
"CREATE TABLE %s AS SELECT %s FROM %s WHERE 0 = 1",
Expand All @@ -719,7 +725,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String
.collect(joining(", ")),
quoted(catalogName, schemaName, tableName));
try {
execute(connection, sql);
execute(session, connection, sql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -774,7 +780,7 @@ protected void renameTable(ConnectorSession session, String catalogName, String
protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName)
throws SQLException
{
execute(connection, format(
execute(session, connection, format(
"ALTER TABLE %s RENAME TO %s",
quoted(catalogName, remoteSchemaName, remoteTableName),
quoted(catalogName, newRemoteSchemaName, newRemoteTableName)));
Expand All @@ -800,9 +806,10 @@ private RemoteTableName constructPageSinkIdsTable(ConnectorSession session, Conn
String pageSinkInsertSql = format("INSERT INTO %s (%s) VALUES (?)",
quoted(pageSinkTable),
pageSinkIdColumnName);
pageSinkInsertSql = queryModifier.apply(session, pageSinkInsertSql);
LongWriteFunction pageSinkIdWriter = (LongWriteFunction) toWriteMapping(session, TRINO_PAGE_SINK_ID_COLUMN_TYPE).getWriteFunction();

execute(connection, pageSinkTableSql);
execute(session, connection, pageSinkTableSql);

try (PreparedStatement statement = connection.prepareStatement(pageSinkInsertSql)) {
int batchSize = 0;
Expand Down Expand Up @@ -868,7 +875,7 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
handle.getPageSinkIdColumnName().get());
}

execute(connection, insertSql);
execute(session, connection, insertSql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -919,7 +926,7 @@ private void addColumn(ConnectorSession session, Connection connection, RemoteTa
"ALTER TABLE %s ADD %s",
quoted(table),
getColumnDefinitionSql(session, column, remoteColumnName));
execute(connection, sql);
execute(session, connection, sql);
}

@Override
Expand All @@ -939,7 +946,7 @@ public void renameColumn(ConnectorSession session, JdbcTableHandle handle, JdbcC
protected void renameColumn(ConnectorSession session, Connection connection, RemoteTableName remoteTableName, String remoteColumnName, String newRemoteColumnName)
throws SQLException
{
execute(connection, format(
execute(session, connection, format(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
quoted(remoteTableName),
quoted(remoteColumnName),
Expand All @@ -956,7 +963,7 @@ public void dropColumn(ConnectorSession session, JdbcTableHandle handle, JdbcCol
"ALTER TABLE %s DROP COLUMN %s",
quoted(handle.asPlainTable().getRemoteTableName()),
quoted(remoteColumnName));
execute(connection, sql);
execute(session, connection, sql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -1076,7 +1083,7 @@ public void createSchema(ConnectorSession session, String schemaName)
protected void createSchema(ConnectorSession session, Connection connection, String remoteSchemaName)
throws SQLException
{
execute(connection, "CREATE SCHEMA " + quoted(remoteSchemaName));
execute(session, connection, "CREATE SCHEMA " + quoted(remoteSchemaName));
}

@Override
Expand All @@ -1096,7 +1103,7 @@ public void dropSchema(ConnectorSession session, String schemaName)
protected void dropSchema(ConnectorSession session, Connection connection, String remoteSchemaName)
throws SQLException
{
execute(connection, "DROP SCHEMA " + quoted(remoteSchemaName));
execute(session, connection, "DROP SCHEMA " + quoted(remoteSchemaName));
}

@Override
Expand All @@ -1118,25 +1125,26 @@ public void renameSchema(ConnectorSession session, String schemaName, String new
protected void renameSchema(ConnectorSession session, Connection connection, String remoteSchemaName, String newRemoteSchemaName)
throws SQLException
{
execute(connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName));
execute(session, connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName));
}

protected void execute(ConnectorSession session, String query)
{
try (Connection connection = connectionFactory.openConnection(session)) {
execute(connection, query);
execute(session, connection, query);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

protected void execute(Connection connection, String query)
protected void execute(ConnectorSession session, Connection connection, String query)
throws SQLException
{
try (Statement statement = connection.createStatement()) {
log.debug("Execute: %s", query);
statement.execute(query);
String modifiedQuery = queryModifier.apply(session, query);
log.debug("Execute: %s", modifiedQuery);
statement.execute(modifiedQuery);
}
catch (SQLException e) {
e.addSuppressed(new RuntimeException("Query: " + query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import com.google.common.base.Joiner;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.JoinType;
Expand All @@ -42,6 +44,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public class DefaultQueryBuilder
Expand All @@ -53,6 +56,19 @@ public class DefaultQueryBuilder
private static final String ALWAYS_TRUE = "1=1";
private static final String ALWAYS_FALSE = "1=0";

private final RemoteQueryModifier queryModifier;

public DefaultQueryBuilder()
{
this(RemoteQueryModifier.NONE);
}

@Inject
public DefaultQueryBuilder(RemoteQueryModifier queryModifier)
{
this.queryModifier = requireNonNull(queryModifier, "queryModifier is null");
}

@Override
public PreparedQuery prepareSelectQuery(
JdbcClient client,
Expand Down Expand Up @@ -163,8 +179,9 @@ public PreparedStatement prepareStatement(
PreparedQuery preparedQuery)
throws SQLException
{
log.debug("Preparing query: %s", preparedQuery.getQuery());
PreparedStatement statement = client.getPreparedStatement(connection, preparedQuery.getQuery());
String modifiedQuery = queryModifier.apply(session, preparedQuery.getQuery());
log.debug("Preparing query: %s", modifiedQuery);
PreparedStatement statement = client.getPreparedStatement(connection, modifiedQuery);

List<QueryParameter> parameters = preparedQuery.getParameters();
for (int i = 0; i < parameters.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.base.CatalogName;
import io.trino.plugin.base.session.SessionPropertiesProvider;
import io.trino.plugin.jdbc.logging.RemoteQueryModifierModule;
import io.trino.plugin.jdbc.mapping.IdentifierMappingModule;
import io.trino.plugin.jdbc.procedure.FlushJdbcMetadataCacheProcedure;
import io.trino.spi.connector.ConnectorAccessControl;
Expand Down Expand Up @@ -49,6 +50,7 @@ public void setup(Binder binder)
{
install(new JdbcDiagnosticModule());
install(new IdentifierMappingModule());
install(new RemoteQueryModifierModule());

newOptionalBinder(binder, ConnectorAccessControl.class);
newOptionalBinder(binder, QueryBuilder.class).setDefault().to(DefaultQueryBuilder.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -57,7 +58,7 @@ public class JdbcPageSink
private final LongWriteFunction pageSinkIdWriteFunction;
private final boolean includePageSinkIdColumn;

public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId)
public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier remoteQueryModifier)
{
try {
connection = jdbcClient.getConnection(session, handle);
Expand Down Expand Up @@ -112,8 +113,12 @@ public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, Jdbc

String insertSql = jdbcClient.buildInsertSql(handle, columnWriters);
try {
insertSql = remoteQueryModifier.apply(session, insertSql);
statement = connection.prepareStatement(insertSql);
}
catch (TrinoException e) {
throw closeAllSuppress(e, connection);
}
catch (SQLException e) {
closeAllSuppress(e, connection);
throw new TrinoException(JDBC_ERROR, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.jdbc;

import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ConnectorInsertTableHandle;
import io.trino.spi.connector.ConnectorOutputTableHandle;
import io.trino.spi.connector.ConnectorPageSink;
Expand All @@ -29,22 +30,24 @@ public class JdbcPageSinkProvider
implements ConnectorPageSinkProvider
{
private final JdbcClient jdbcClient;
private RemoteQueryModifier queryModifier;

@Inject
public JdbcPageSinkProvider(JdbcClient jdbcClient)
public JdbcPageSinkProvider(JdbcClient jdbcClient, RemoteQueryModifier remoteQueryModifier)
{
this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null");
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle tableHandle, ConnectorPageSinkId pageSinkId)
{
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId);
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier);
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle tableHandle, ConnectorPageSinkId pageSinkId)
{
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId);
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier);
}
}
Loading

0 comments on commit e2771a8

Please sign in to comment.