diff --git a/tempto-core/src/main/java/io/trino/tempto/query/JdbcConnectionsPool.java b/tempto-core/src/main/java/io/trino/tempto/query/JdbcConnectionsPool.java index f34d9b12..27d82076 100644 --- a/tempto-core/src/main/java/io/trino/tempto/query/JdbcConnectionsPool.java +++ b/tempto-core/src/main/java/io/trino/tempto/query/JdbcConnectionsPool.java @@ -22,6 +22,7 @@ import static com.google.common.collect.Maps.newHashMap; import static io.trino.tempto.internal.query.JdbcUtils.dataSource; +import static java.util.Objects.requireNonNull; public class JdbcConnectionsPool { @@ -29,6 +30,12 @@ public class JdbcConnectionsPool public Connection connectionFor(JdbcConnectivityParamsState jdbcParamsState) throws SQLException + { + return configureConnection(jdbcParamsState, createConnection(jdbcParamsState)); + } + + protected Connection createConnection(JdbcConnectivityParamsState jdbcParamsState) + throws SQLException { if (!dataSources.containsKey(jdbcParamsState)) { dataSources.put(jdbcParamsState, dataSource(jdbcParamsState)); @@ -39,6 +46,13 @@ public Connection connectionFor(JdbcConnectivityParamsState jdbcParamsState) // this should never happen, `javax.sql.DataSource#getConnection()` should not return null throw new IllegalStateException("No connection was created for: " + jdbcParamsState.getName()); } + return connection; + } + + protected static Connection configureConnection(JdbcConnectivityParamsState jdbcParamsState, Connection connection) + throws SQLException + { + requireNonNull(connection, "connection is null"); if (!jdbcParamsState.prepareStatements.isEmpty()) { try (Statement statement = connection.createStatement()) { for (String query : jdbcParamsState.prepareStatements) {