Skip to content

Commit

Permalink
Improvements in Database metadata to prevent Statement leaks and enha…
Browse files Browse the repository at this point in the history
…nce Statement caching (#806)

* Closes Statements on completion - prevents Statement leaks from Driver

* Fix for Metadata Caching for no catalog scenarios.

* Handle case where user closes cached Prepared Statements from ResultSets

* Add Test for Prepared Statement Metadata Caching

* Changes as per recommendations

* Few more improvements
  • Loading branch information
cheenamalhotra authored Oct 27, 2018
1 parent a9bebfa commit cd3891c
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 82 deletions.
39 changes: 19 additions & 20 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java
Original file line number Diff line number Diff line change
Expand Up @@ -1700,9 +1700,11 @@ private void getDestinationMetadata() throws SQLServerException {
SQLServerException.getErrString("R_invalidDestinationTable"), null, false);
}

String escapedDestinationTableName = Util.escapeSingleQuotes(destinationTableName);

SQLServerResultSet rs = null;
SQLServerResultSet rsMoreMetaData = null;
SQLServerStatement stmt = null;
String metaDataQuery = null;

try {
if (null != destinationTableMetadata) {
Expand All @@ -1713,35 +1715,34 @@ private void getDestinationMetadata() throws SQLServerException {

// Get destination metadata
rs = stmt.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM "
+ Util.escapeSingleQuotes(destinationTableName) + " '");
+ escapedDestinationTableName + " '");
}

destColumnCount = rs.getMetaData().getColumnCount();
destColumnMetadata = new HashMap<>();
destCekTable = rs.getCekTable();

if (!connection.getServerSupportsColumnEncryption()) {
// SQL server prior to 2016 does not support encryption_type
rsMoreMetaData = ((SQLServerStatement) connection.createStatement())
.executeQueryInternal("select collation_name from sys.columns where " + "object_id=OBJECT_ID('"
+ Util.escapeSingleQuotes(destinationTableName) + "') " + "order by column_id ASC");
metaDataQuery = "select collation_name from sys.columns where " + "object_id=OBJECT_ID('"
+ escapedDestinationTableName + "') " + "order by column_id ASC";
} else {
rsMoreMetaData = ((SQLServerStatement) connection.createStatement()).executeQueryInternal(
"select collation_name, encryption_type from sys.columns where " + "object_id=OBJECT_ID('"
+ Util.escapeSingleQuotes(destinationTableName) + "') " + "order by column_id ASC");
metaDataQuery = "select collation_name, encryption_type from sys.columns where "
+ "object_id=OBJECT_ID('" + escapedDestinationTableName + "') " + "order by column_id ASC";
}
for (int i = 1; i <= destColumnCount; ++i) {
if (rsMoreMetaData.next()) {
if (!connection.getServerSupportsColumnEncryption()) {

try (SQLServerStatement statementMoreMetadata = (SQLServerStatement) connection.createStatement();
SQLServerResultSet rsMoreMetaData = statementMoreMetadata.executeQueryInternal(metaDataQuery)) {
for (int i = 1; i <= destColumnCount; ++i) {
if (rsMoreMetaData.next()) {
String bulkCopyEncryptionType = null;
if (connection.getServerSupportsColumnEncryption()) {
bulkCopyEncryptionType = rsMoreMetaData.getString("encryption_type");
}
destColumnMetadata.put(i, new BulkColumnMetaData(rs.getColumn(i),
rsMoreMetaData.getString("collation_name"), null));
rsMoreMetaData.getString("collation_name"), bulkCopyEncryptionType));
} else {
destColumnMetadata.put(i,
new BulkColumnMetaData(rs.getColumn(i), rsMoreMetaData.getString("collation_name"),
rsMoreMetaData.getString("encryption_type")));
destColumnMetadata.put(i, new BulkColumnMetaData(rs.getColumn(i)));
}
} else {
destColumnMetadata.put(i, new BulkColumnMetaData(rs.getColumn(i)));
}
}
} catch (SQLException e) {
Expand All @@ -1752,8 +1753,6 @@ private void getDestinationMetadata() throws SQLServerException {
rs.close();
if (null != stmt)
stmt.close();
if (null != rsMoreMetaData)
rsMoreMetaData.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import java.sql.SQLTimeoutException;
import java.text.MessageFormat;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
Expand Down Expand Up @@ -80,16 +82,35 @@ CallableStatement prepare(SQLServerConnection conn) throws SQLServerException {
}

final class HandleAssociation {
final String databaseName;
final CallableStatement stmt;
Map<String, CallableStatement> statementMap;
boolean nullCatalog = false;
CallableStatement stmt;

HandleAssociation(String databaseName, CallableStatement stmt) {
this.databaseName = databaseName;
this.stmt = stmt;
HandleAssociation() {
if (null == statementMap) {
statementMap = new HashMap<>();
}
}

final void addToMap(String databaseName, CallableStatement stmt) {
if (null != databaseName) {
nullCatalog = false;
statementMap.put(databaseName, stmt);
} else {
nullCatalog = true;
this.stmt = stmt;
}
}

final void close() throws SQLServerException {
((SQLServerCallableStatement) stmt).close();
final CallableStatement getMappedStatement(String databaseName) {
if (null != databaseName) {
if (null != statementMap && statementMap.containsKey(databaseName)) {
return statementMap.get(databaseName);
}
return null;
} else {
return stmt;
}
}
}

Expand Down Expand Up @@ -244,13 +265,15 @@ private void checkClosed() throws SQLServerException {
* @throws SQLTimeoutException
*/
private SQLServerResultSet getResultSetFromInternalQueries(String catalog,
String query) throws SQLServerException, SQLTimeoutException {
String query) throws SQLException, SQLTimeoutException {
checkClosed();
String orgCat = null;
orgCat = switchCatalogs(catalog);
SQLServerResultSet rs = null;
try {
rs = ((SQLServerStatement) connection.createStatement()).executeQueryInternal(query);
SQLServerStatement statement = (SQLServerStatement) connection.createStatement();
statement.closeOnCompletion();
rs = statement.executeQueryInternal(query);
} finally {
if (null != orgCat) {
connection.setCatalog(orgCat);
Expand All @@ -266,15 +289,24 @@ private CallableStatement getCallableStatementHandle(CallableHandles request,
String catalog) throws SQLServerException {
CallableStatement CS = null;
HandleAssociation hassoc = handleMap.get(request);
if (null == hassoc || null == hassoc.databaseName || !hassoc.databaseName.equals(catalog)) {
CS = request.prepare(connection);
hassoc = new HandleAssociation(catalog, CS);
HandleAssociation previous = handleMap.put(request, hassoc);
if (null != previous) {
previous.close();
try {
if (null == hassoc) {
CS = request.prepare(connection);
hassoc = new HandleAssociation();
hassoc.addToMap(catalog, CS);
} else { // hassoc != null
CS = hassoc.getMappedStatement(catalog);
// No Cached Statement yet
if (null == CS || CS.isClosed()) {
CS = request.prepare(connection);
hassoc.addToMap(catalog, CS);
}
}
handleMap.put(request, hassoc);
} catch (SQLException e) {
SQLServerException.makeFromDriverError(connection, CS, e.toString(), null, false);
}
return hassoc.stmt;
return CS;
}

/**
Expand Down Expand Up @@ -411,7 +443,7 @@ public boolean supportsSharding() throws SQLException {
}

@Override
public java.sql.ResultSet getCatalogs() throws SQLServerException, SQLTimeoutException {
public java.sql.ResultSet getCatalogs() throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -762,7 +794,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St

@Override
public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2,
String tab2) throws SQLServerException, SQLTimeoutException {
String tab2) throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -827,7 +859,7 @@ public String getDriverVersion() throws SQLServerException {

@Override
public java.sql.ResultSet getExportedKeys(String cat, String schema,
String table) throws SQLServerException, SQLTimeoutException {
String table) throws SQLException, SQLTimeoutException {
return getCrossReference(cat, schema, table, null, null, null);
}

Expand All @@ -845,11 +877,11 @@ public String getIdentifierQuoteString() throws SQLServerException {

@Override
public java.sql.ResultSet getImportedKeys(String cat, String schema,
String table) throws SQLServerException, SQLTimeoutException {
String table) throws SQLException, SQLTimeoutException {
return getCrossReference(null, null, null, cat, schema, table);
}

private ResultSet executeSPFkeys(String[] procParams) throws SQLServerException, SQLTimeoutException {
private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException {
String tempTableName = "@jdbc_temp_fkeys_result";
String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, "
+ "PKTABLE_NAME sysname, " + "PKCOLUMN_NAME sysname, " + "FKTABLE_QUALIFIER sysname, "
Expand All @@ -859,21 +891,8 @@ private ResultSet executeSPFkeys(String[] procParams) throws SQLServerException,
+ " EXEC sp_fkeys ?,?,?,?,?,?;" + "SELECT t.PKTABLE_QUALIFIER AS PKTABLE_CAT, "
+ "t.PKTABLE_OWNER AS PKTABLE_SCHEM, " + "t.PKTABLE_NAME, " + "t.PKCOLUMN_NAME, "
+ "t.FKTABLE_QUALIFIER AS FKTABLE_CAT, " + "t.FKTABLE_OWNER AS FKTABLE_SCHEM, " + "t.FKTABLE_NAME, "
+ "t.FKCOLUMN_NAME, " + "t.KEY_SEQ, " + "CASE s.update_referential_action " + "WHEN 1 THEN 0 " + // cascade
// -
// note
// that
// sp_fkey
// and
// sys.foreign_keys
// have
// flipped
// values
// for
// cascade
// and
// no
// action
+ "t.FKCOLUMN_NAME, " + "t.KEY_SEQ, " + "CASE s.update_referential_action " + "WHEN 1 THEN 0 " +
// cascade - note that sp_fkey and sys.foreign_keys have flipped values for cascade and no action
"WHEN 0 THEN 3 " + // no action
"WHEN 2 THEN 2 " + // set null
"WHEN 3 THEN 4 " + // set default
Expand All @@ -882,6 +901,7 @@ private ResultSet executeSPFkeys(String[] procParams) throws SQLServerException,
+ "t.DEFERRABILITY " + "FROM " + tempTableName + " t "
+ "LEFT JOIN sys.foreign_keys s ON t.FK_NAME = s.name collate database_default;";
SQLServerCallableStatement cstmt = (SQLServerCallableStatement) connection.prepareCall(sql);
cstmt.closeOnCompletion();
for (int i = 0; i < 6; i++) {
cstmt.setString(i + 1, procParams[i]);
}
Expand Down Expand Up @@ -987,7 +1007,7 @@ public int getMaxColumnsInTable() throws SQLServerException {
}

@Override
public int getMaxConnections() throws SQLServerException, SQLTimeoutException {
public int getMaxConnections() throws SQLException, SQLTimeoutException {
checkClosed();
try {
String s = "sp_configure 'user connections'";
Expand Down Expand Up @@ -1199,7 +1219,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t
}

@Override
public java.sql.ResultSet getSchemas() throws SQLServerException, SQLTimeoutException {
public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand All @@ -1209,7 +1229,7 @@ public java.sql.ResultSet getSchemas() throws SQLServerException, SQLTimeoutExce
}

private java.sql.ResultSet getSchemasInternal(String catalog,
String schemaPattern) throws SQLServerException, SQLTimeoutException {
String schemaPattern) throws SQLException, SQLTimeoutException {

String s;
// The schemas that return null for catalog name, these are prebuilt
Expand Down Expand Up @@ -1268,10 +1288,10 @@ private java.sql.ResultSet getSchemasInternal(String catalog,
} else {

// The prepared statement is not closed after execution.
// Yes we will "leak a server handle" per execution but the
// connection closure will release them
//
// No we will not "leak a server handle" per execution
// as the prepared statement will close as the resultset 'rs' is closed
SQLServerPreparedStatement ps = (SQLServerPreparedStatement) connection.prepareStatement(s);
ps.closeOnCompletion();
ps.setString(1, schemaPattern);
rs = (SQLServerResultSet) ps.executeQueryInternal();
}
Expand Down Expand Up @@ -1374,7 +1394,7 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema,
}

@Override
public java.sql.ResultSet getTableTypes() throws SQLServerException, SQLTimeoutException {
public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand All @@ -1391,7 +1411,7 @@ public String getTimeDateFunctions() throws SQLServerException {
}

@Override
public java.sql.ResultSet getTypeInfo() throws SQLServerException, SQLTimeoutException {
public java.sql.ResultSet getTypeInfo() throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -1481,25 +1501,13 @@ public String getUserName() throws SQLServerException, SQLTimeoutException {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
checkClosed();
SQLServerStatement s = null;
SQLServerResultSet rs = null;
String result = "";

try {
s = (SQLServerStatement) connection.createStatement();
rs = s.executeQueryInternal("select system_user");
try (SQLServerStatement s = (SQLServerStatement) connection.createStatement();
SQLServerResultSet rs = s.executeQueryInternal("select system_user")) {
// Select system_user will always return a row.
boolean next = rs.next();
assert next;

result = rs.getString(1);
} finally {
if (rs != null) {
rs.close();
}
if (s != null) {
s.close();
}
}
return result;
}
Expand Down Expand Up @@ -2142,7 +2150,7 @@ public boolean supportsBatchUpdates() throws SQLServerException {

@Override
public java.sql.ResultSet getUDTs(String catalog, String schemaPattern, String typeNamePattern,
int[] types) throws SQLServerException, SQLTimeoutException {
int[] types) throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -2244,7 +2252,7 @@ public boolean supportsResultSetHoldability(int holdability) throws SQLServerExc

@Override
public ResultSet getAttributes(String catalog, String schemaPattern, String typeNamePattern,
String attributeNamePattern) throws SQLServerException, SQLTimeoutException {
String attributeNamePattern) throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -2275,7 +2283,7 @@ public ResultSet getAttributes(String catalog, String schemaPattern, String type

@Override
public ResultSet getSuperTables(String catalog, String schemaPattern,
String tableNamePattern) throws SQLServerException, SQLTimeoutException {
String tableNamePattern) throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand All @@ -2289,7 +2297,7 @@ public ResultSet getSuperTables(String catalog, String schemaPattern,

@Override
public ResultSet getSuperTypes(String catalog, String schemaPattern,
String typeNamePattern) throws SQLServerException, SQLTimeoutException {
String typeNamePattern) throws SQLException, SQLTimeoutException {
if (loggerExternal.isLoggable(Level.FINER) && Util.IsActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down
Loading

0 comments on commit cd3891c

Please sign in to comment.