Skip to content

Commit

Permalink
GH-38576: [Java] Change JDBC driver to optionally preserve cookies an…
Browse files Browse the repository at this point in the history
…d auth tokens when getting streams (#38580)

### Rationale for this change
This change restores the original behavior of transmitting existing cookies and auth tokens when getting separate
streams returned by getFlightInfo after adding support for multiple endpoints.

These properties are now optional though.

### What changes are included in this PR?
- Change the JDBC driver to add new properties "retainCookies" and "retainAuth"
- These properties allow internally spawned connections for getting streams to use the cookies and bearer tokens from the original connection.
- Add tests for validating defaults from ArrowFlightSqlClient.Builder

### Are these changes tested?
Unit tests have been added.

### Are there any user-facing changes?
Yes. There are now properties and they are documented.
* Closes: #38576

Authored-by: James Duong <james.duong@improving.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
jduo authored Nov 7, 2023
1 parent 25c18d8 commit fafd48c
Show file tree
Hide file tree
Showing 12 changed files with 448 additions and 170 deletions.
11 changes: 11 additions & 0 deletions docs/source/java/flight_sql_jdbc_driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ case-sensitive. The supported parameters are:
- true
- When TLS is enabled, whether to use the system certificate store

* - retainCookies
- true
- Whether to use cookies from the initial connection in subsequent
internal connections when retrieving streams from separate endpoints.

* - retainAuth
- true
- Whether to use bearer tokens obtained from the initial connection
in subsequent internal connections used for retrieving streams
from separate endpoints.

Note that URI values must be URI-encoded if they contain characters such
as !, @, $, etc.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddlewar
*/
public static class Factory implements FlightClientMiddleware.Factory {
private final ClientHeaderHandler headerHandler;
private CredentialCallOption credentialCallOption;
private CredentialCallOption credentialCallOption = null;

/**
* Construct a factory with the given header handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withDisableCertificateVerification(config.getDisableCertificateVerification())
.withToken(config.getToken())
.withCallOptions(config.toCallOption())
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.build();
} catch (final SQLException e) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
Expand Down Expand Up @@ -425,29 +426,68 @@ public static final class Builder {
private final Set<CallOption> options = new HashSet<>();
private String host;
private int port;
private String username;
private String password;
private String trustStorePath;
private String trustStorePassword;
private String token;
private boolean useEncryption;
private boolean disableCertificateVerification;
private boolean useSystemTrustStore;
private String tlsRootCertificatesPath;
private String clientCertificatePath;
private String clientKeyPath;

@VisibleForTesting
String username;

@VisibleForTesting
String password;

@VisibleForTesting
String trustStorePath;

@VisibleForTesting
String trustStorePassword;

@VisibleForTesting
String token;

@VisibleForTesting
boolean useEncryption = true;

@VisibleForTesting
boolean disableCertificateVerification;

@VisibleForTesting
boolean useSystemTrustStore = true;

@VisibleForTesting
String tlsRootCertificatesPath;

@VisibleForTesting
String clientCertificatePath;

@VisibleForTesting
String clientKeyPath;

@VisibleForTesting
private BufferAllocator allocator;

public Builder() {
@VisibleForTesting
boolean retainCookies = true;

@VisibleForTesting
boolean retainAuth = true;

// These two middlewares are for internal use within build() and should not be exposed by builder APIs.
// Note that these middlewares may not necessarily be registered.
@VisibleForTesting
ClientIncomingAuthHeaderMiddleware.Factory authFactory
= new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());

@VisibleForTesting
ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

public Builder() {
}

/**
* Copies the builder.
*
* @param original The builder to base this copy off of.
*/
private Builder(Builder original) {
@VisibleForTesting
Builder(Builder original) {
this.middlewareFactories.addAll(original.middlewareFactories);
this.options.addAll(original.options);
this.host = original.host;
Expand All @@ -464,6 +504,14 @@ private Builder(Builder original) {
this.clientCertificatePath = original.clientCertificatePath;
this.clientKeyPath = original.clientKeyPath;
this.allocator = original.allocator;

if (original.retainCookies) {
this.cookieFactory = original.cookieFactory;
}

if (original.retainAuth) {
this.authFactory = original.authFactory;
}
}

/**
Expand Down Expand Up @@ -622,6 +670,28 @@ public Builder withBufferAllocator(final BufferAllocator allocator) {
return this;
}

/**
* Indicates if cookies should be re-used by connections spawned for getStreams() calls.
* @param retainCookies The flag indicating if cookies should be re-used.
* @return this builder instance.
*/
public Builder withRetainCookies(boolean retainCookies) {
this.retainCookies = retainCookies;
return this;
}

/**
* Indicates if bearer tokens negotiated should be re-used by connections
* spawned for getStreams() calls.
*
* @param retainAuth The flag indicating if auth tokens should be re-used.
* @return this builder instance.
*/
public Builder withRetainAuth(boolean retainAuth) {
this.retainAuth = retainAuth;
return this;
}

/**
* Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler.
*
Expand Down Expand Up @@ -675,13 +745,11 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
// Copy middlewares so that the build method doesn't change the state of the builder fields itself.
Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories);
FlightClient client = null;
boolean isUsingUserPasswordAuth = username != null && token == null;

try {
ClientIncomingAuthHeaderMiddleware.Factory authFactory = null;
// Token should take priority since some apps pass in a username/password even when a token is provided
if (username != null && token == null) {
authFactory =
new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
Expand Down Expand Up @@ -722,10 +790,17 @@ public ArrowFlightSqlClientHandler build() throws SQLException {

client = clientBuilder.build();
final ArrayList<CallOption> credentialOptions = new ArrayList<>();
if (authFactory != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
if (isUsingUserPasswordAuth) {
// If the authFactory has already been used for a handshake, use the existing token.
// This can occur if the authFactory is being re-used for a new connection spawned for getStream().
if (authFactory.getCredentialCallOption() != null) {
credentialOptions.add(authFactory.getCredentialCallOption());
} else {
// Otherwise do the handshake and get the token if possible.
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
}
} else if (token != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ public int threadPoolSize() {
return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse cookies from the main connection.
*/
public boolean retainCookies() {
return ArrowFlightConnectionProperty.RETAIN_COOKIES.getBoolean(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse bearer tokens created from the main connection.
*/
public boolean retainAuth() {
return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -191,7 +207,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
CLIENT_CERTIFICATE("clientCertificate", null, Type.STRING, false),
CLIENT_KEY("clientKey", null, Type.STRING, false),
THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false),
TOKEN("token", null, Type.STRING, false);
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false);

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,21 @@ public void testGetEncryptedClientWithBadMTlsCertPath() {
final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
userTest, passTest);

assertThrows(SQLException.class, () -> new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(badClientMTlsCertPath)
.withClientKey(clientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build());
assertThrows(SQLException.class, () -> {
try (ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(badClientMTlsCertPath)
.withClientKey(clientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build()) {
Assert.fail();
}
});
}

/**
Expand All @@ -162,17 +166,21 @@ public void testGetEncryptedClientWithBadMTlsKeyPath() {
final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
userTest, passTest);

assertThrows(SQLException.class, () -> new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(clientMTlsCertPath)
.withClientKey(badClientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build());
assertThrows(SQLException.class, () -> {
try (ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(clientMTlsCertPath)
.withClientKey(badClientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build()) {
Assert.fail();
}
});
}

/**
Expand Down Expand Up @@ -222,7 +230,7 @@ public void testGetEncryptedConnectionWithValidCredentialsAndTlsRootsPath() thro
final ArrowFlightJdbcDataSource dataSource =
ArrowFlightJdbcDataSource.createNewDataSource(properties);
try (final Connection connection = dataSource.getConnection()) {
assert connection.isValid(300);
Assert.assertTrue(connection.isValid(300));
}
}

Expand All @@ -245,7 +253,7 @@ public void testGetNonAuthenticatedEncryptedConnection() throws Exception {

final ArrowFlightJdbcDataSource dataSource = ArrowFlightJdbcDataSource.createNewDataSource(properties);
try (final Connection connection = dataSource.getConnection()) {
assert connection.isValid(300);
Assert.assertTrue(connection.isValid(300));
}
}

Expand Down
Loading

0 comments on commit fafd48c

Please sign in to comment.