Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context appender to sql based connectors #14500

Merged
merged 2 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ public static Builder builder()
public static class Builder
s2lomon marked this conversation as resolved.
Show resolved Hide resolved
{
private ConnectorIdentity identity = ConnectorIdentity.ofUser("user");
private final Optional<String> source = Optional.of("test");
private Optional<String> source = Optional.of("test");
private TimeZoneKey timeZoneKey = UTC_KEY;
private final Locale locale = ENGLISH;
private final Optional<String> traceToken = Optional.empty();
private Optional<String> traceToken = Optional.empty();
private Optional<Instant> start = Optional.empty();
private List<PropertyMetadata<?>> propertyMetadatas = ImmutableList.of();
private Map<String, Object> propertyValues = ImmutableMap.of();
Expand All @@ -177,6 +177,18 @@ public Builder setStart(Instant start)
return this;
}

public Builder setSource(String source)
s2lomon marked this conversation as resolved.
Show resolved Hide resolved
{
this.source = Optional.of(source);
return this;
}

public Builder setTraceToken(String token)
{
this.traceToken = Optional.of(token);
return this;
}

public Builder setPropertyMetadata(List<PropertyMetadata<?>> propertyMetadatas)
{
requireNonNull(propertyMetadatas, "propertyMetadatas is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.io.Closer;
import io.airlift.log.Logger;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
Expand Down Expand Up @@ -104,6 +105,7 @@ public abstract class BaseJdbcClient
protected final String identifierQuote;
protected final Set<String> jdbcTypesMappedToVarchar;
private final IdentifierMapping identifierMapping;
private final RemoteQueryModifier queryModifier;

private final boolean supportsRetries;

Expand All @@ -112,14 +114,16 @@ public BaseJdbcClient(
String identifierQuote,
ConnectionFactory connectionFactory,
QueryBuilder queryBuilder,
IdentifierMapping identifierMapping)
IdentifierMapping identifierMapping,
RemoteQueryModifier remoteQueryModifier)
{
this(
identifierQuote,
connectionFactory,
queryBuilder,
config.getJdbcTypesMappedToVarchar(),
identifierMapping,
remoteQueryModifier,
false);
}

Expand All @@ -129,6 +133,7 @@ public BaseJdbcClient(
QueryBuilder queryBuilder,
Set<String> jdbcTypesMappedToVarchar,
IdentifierMapping identifierMapping,
RemoteQueryModifier remoteQueryModifier,
boolean supportsRetries)
{
this.identifierQuote = requireNonNull(identifierQuote, "identifierQuote is null");
Expand All @@ -138,6 +143,7 @@ public BaseJdbcClient(
.addAll(requireNonNull(jdbcTypesMappedToVarchar, "jdbcTypesMappedToVarchar is null"))
.build();
this.identifierMapping = requireNonNull(identifierMapping, "identifierMapping is null");
this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null");
this.supportsRetries = supportsRetries;
}

Expand Down Expand Up @@ -610,7 +616,7 @@ protected JdbcOutputTableHandle createTable(ConnectorSession session, ConnectorT

RemoteTableName remoteTableName = new RemoteTableName(Optional.ofNullable(catalog), Optional.ofNullable(remoteSchema), remoteTargetTableName);
String sql = createTableSql(remoteTableName, columnList.build(), tableMetadata);
execute(connection, sql);
execute(session, connection, sql);

return new JdbcOutputTableHandle(
catalog,
Expand Down Expand Up @@ -682,7 +688,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl
}

String remoteTemporaryTableName = identifierMapping.toRemoteTableName(identity, connection, remoteSchema, generateTemporaryTableName());
copyTableSchema(connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build());
copyTableSchema(session, connection, catalog, remoteSchema, remoteTable, remoteTemporaryTableName, columnNames.build());

Optional<ColumnMetadata> pageSinkIdColumn = Optional.empty();
if (shouldUseFaultTolerantExecution(session)) {
Expand All @@ -709,7 +715,7 @@ public JdbcOutputTableHandle beginInsertTable(ConnectorSession session, JdbcTabl
}
}

protected void copyTableSchema(Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames)
protected void copyTableSchema(ConnectorSession session, Connection connection, String catalogName, String schemaName, String tableName, String newTableName, List<String> columnNames)
{
String sql = format(
"CREATE TABLE %s AS SELECT %s FROM %s WHERE 0 = 1",
Expand All @@ -719,7 +725,7 @@ protected void copyTableSchema(Connection connection, String catalogName, String
.collect(joining(", ")),
quoted(catalogName, schemaName, tableName));
try {
execute(connection, sql);
execute(session, connection, sql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -774,7 +780,7 @@ protected void renameTable(ConnectorSession session, String catalogName, String
protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName)
throws SQLException
{
execute(connection, format(
execute(session, connection, format(
"ALTER TABLE %s RENAME TO %s",
quoted(catalogName, remoteSchemaName, remoteTableName),
quoted(catalogName, newRemoteSchemaName, newRemoteTableName)));
Expand All @@ -800,9 +806,10 @@ private RemoteTableName constructPageSinkIdsTable(ConnectorSession session, Conn
String pageSinkInsertSql = format("INSERT INTO %s (%s) VALUES (?)",
quoted(pageSinkTable),
pageSinkIdColumnName);
pageSinkInsertSql = queryModifier.apply(session, pageSinkInsertSql);
LongWriteFunction pageSinkIdWriter = (LongWriteFunction) toWriteMapping(session, TRINO_PAGE_SINK_ID_COLUMN_TYPE).getWriteFunction();

execute(connection, pageSinkTableSql);
execute(session, connection, pageSinkTableSql);

try (PreparedStatement statement = connection.prepareStatement(pageSinkInsertSql)) {
int batchSize = 0;
Expand Down Expand Up @@ -868,7 +875,7 @@ public void finishInsertTable(ConnectorSession session, JdbcOutputTableHandle ha
handle.getPageSinkIdColumnName().get());
}

execute(connection, insertSql);
execute(session, connection, insertSql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -919,7 +926,7 @@ private void addColumn(ConnectorSession session, Connection connection, RemoteTa
"ALTER TABLE %s ADD %s",
quoted(table),
getColumnDefinitionSql(session, column, remoteColumnName));
execute(connection, sql);
execute(session, connection, sql);
}

@Override
Expand All @@ -939,7 +946,7 @@ public void renameColumn(ConnectorSession session, JdbcTableHandle handle, JdbcC
protected void renameColumn(ConnectorSession session, Connection connection, RemoteTableName remoteTableName, String remoteColumnName, String newRemoteColumnName)
throws SQLException
{
execute(connection, format(
execute(session, connection, format(
"ALTER TABLE %s RENAME COLUMN %s TO %s",
quoted(remoteTableName),
quoted(remoteColumnName),
Expand All @@ -956,7 +963,7 @@ public void dropColumn(ConnectorSession session, JdbcTableHandle handle, JdbcCol
"ALTER TABLE %s DROP COLUMN %s",
quoted(handle.asPlainTable().getRemoteTableName()),
quoted(remoteColumnName));
execute(connection, sql);
execute(session, connection, sql);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
Expand Down Expand Up @@ -1076,7 +1083,7 @@ public void createSchema(ConnectorSession session, String schemaName)
protected void createSchema(ConnectorSession session, Connection connection, String remoteSchemaName)
throws SQLException
{
execute(connection, "CREATE SCHEMA " + quoted(remoteSchemaName));
execute(session, connection, "CREATE SCHEMA " + quoted(remoteSchemaName));
}

@Override
Expand All @@ -1096,7 +1103,7 @@ public void dropSchema(ConnectorSession session, String schemaName)
protected void dropSchema(ConnectorSession session, Connection connection, String remoteSchemaName)
throws SQLException
{
execute(connection, "DROP SCHEMA " + quoted(remoteSchemaName));
execute(session, connection, "DROP SCHEMA " + quoted(remoteSchemaName));
}

@Override
Expand All @@ -1118,25 +1125,26 @@ public void renameSchema(ConnectorSession session, String schemaName, String new
protected void renameSchema(ConnectorSession session, Connection connection, String remoteSchemaName, String newRemoteSchemaName)
throws SQLException
{
execute(connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName));
execute(session, connection, "ALTER SCHEMA " + quoted(remoteSchemaName) + " RENAME TO " + quoted(newRemoteSchemaName));
}

protected void execute(ConnectorSession session, String query)
{
try (Connection connection = connectionFactory.openConnection(session)) {
execute(connection, query);
execute(session, connection, query);
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

protected void execute(Connection connection, String query)
protected void execute(ConnectorSession session, Connection connection, String query)
throws SQLException
{
try (Statement statement = connection.createStatement()) {
log.debug("Execute: %s", query);
statement.execute(query);
String modifiedQuery = queryModifier.apply(session, query);
log.debug("Execute: %s", modifiedQuery);
statement.execute(modifiedQuery);
}
catch (SQLException e) {
e.addSuppressed(new RuntimeException("Query: " + query));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import com.google.common.base.Joiner;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.JoinType;
Expand All @@ -42,6 +44,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public class DefaultQueryBuilder
Expand All @@ -53,6 +56,19 @@ public class DefaultQueryBuilder
private static final String ALWAYS_TRUE = "1=1";
private static final String ALWAYS_FALSE = "1=0";

private final RemoteQueryModifier queryModifier;

public DefaultQueryBuilder()
s2lomon marked this conversation as resolved.
Show resolved Hide resolved
{
this(RemoteQueryModifier.NONE);
}

@Inject
public DefaultQueryBuilder(RemoteQueryModifier queryModifier)
{
this.queryModifier = requireNonNull(queryModifier, "queryModifier is null");
}

@Override
public PreparedQuery prepareSelectQuery(
JdbcClient client,
Expand Down Expand Up @@ -163,8 +179,9 @@ public PreparedStatement prepareStatement(
PreparedQuery preparedQuery)
throws SQLException
{
log.debug("Preparing query: %s", preparedQuery.getQuery());
PreparedStatement statement = client.getPreparedStatement(connection, preparedQuery.getQuery());
String modifiedQuery = queryModifier.apply(session, preparedQuery.getQuery());
log.debug("Preparing query: %s", modifiedQuery);
PreparedStatement statement = client.getPreparedStatement(connection, modifiedQuery);

List<QueryParameter> parameters = preparedQuery.getParameters();
for (int i = 0; i < parameters.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.base.CatalogName;
import io.trino.plugin.base.session.SessionPropertiesProvider;
import io.trino.plugin.jdbc.logging.RemoteQueryModifierModule;
import io.trino.plugin.jdbc.mapping.IdentifierMappingModule;
import io.trino.plugin.jdbc.procedure.FlushJdbcMetadataCacheProcedure;
import io.trino.spi.connector.ConnectorAccessControl;
Expand Down Expand Up @@ -49,6 +50,7 @@ public void setup(Binder binder)
{
install(new JdbcDiagnosticModule());
install(new IdentifierMappingModule());
install(new RemoteQueryModifierModule());

newOptionalBinder(binder, ConnectorAccessControl.class);
newOptionalBinder(binder, QueryBuilder.class).setDefault().to(DefaultQueryBuilder.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -57,7 +58,7 @@ public class JdbcPageSink
private final LongWriteFunction pageSinkIdWriteFunction;
private final boolean includePageSinkIdColumn;

public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId)
public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient, ConnectorPageSinkId pageSinkId, RemoteQueryModifier remoteQueryModifier)
{
try {
connection = jdbcClient.getConnection(session, handle);
Expand Down Expand Up @@ -112,8 +113,12 @@ public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, Jdbc

String insertSql = jdbcClient.buildInsertSql(handle, columnWriters);
try {
insertSql = remoteQueryModifier.apply(session, insertSql);
statement = connection.prepareStatement(insertSql);
}
catch (TrinoException e) {
throw closeAllSuppress(e, connection);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would throw new TrinoException with e as cause. So all you need to do is catch (SQLException | TrinoException e)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that it's correct from the exception handling stnadnpoint. We already have the exception and there is no need to wrap it again without adding any meaningful information to the caller. Plus we are not returning JDBC_ERROR but rather NON_TRANSIENT_JDBC_ERROR wich is the whole point of this exception being thrown.

}
catch (SQLException e) {
closeAllSuppress(e, connection);
throw new TrinoException(JDBC_ERROR, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.jdbc;

import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.ConnectorInsertTableHandle;
import io.trino.spi.connector.ConnectorOutputTableHandle;
import io.trino.spi.connector.ConnectorPageSink;
Expand All @@ -29,22 +30,24 @@ public class JdbcPageSinkProvider
implements ConnectorPageSinkProvider
{
private final JdbcClient jdbcClient;
private RemoteQueryModifier queryModifier;

@Inject
public JdbcPageSinkProvider(JdbcClient jdbcClient)
public JdbcPageSinkProvider(JdbcClient jdbcClient, RemoteQueryModifier remoteQueryModifier)
{
this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
this.queryModifier = requireNonNull(remoteQueryModifier, "remoteQueryModifier is null");
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle tableHandle, ConnectorPageSinkId pageSinkId)
{
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId);
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier);
}

@Override
public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle tableHandle, ConnectorPageSinkId pageSinkId)
{
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId);
return new JdbcPageSink(session, (JdbcOutputTableHandle) tableHandle, jdbcClient, pageSinkId, queryModifier);
}
}
Loading