Skip to content

Commit

Permalink
GH-38576: [Java] Change JDBC driver to optionally preserve cookies an…
Browse files Browse the repository at this point in the history
…d auth tokens when getting streams

- Change the JDBC driver to add new properties "retainCookies" and "retainAuth"
- These properties allow internally spawned connections for getting streams to
use the cookies and bearer tokens from the original connection.
- Add tests for validating defaults from ArrowFlightSqlClient.Builder
  • Loading branch information
jduo committed Nov 6, 2023
1 parent 4ff1a29 commit c23d4be
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 24 deletions.
11 changes: 11 additions & 0 deletions docs/source/java/flight_sql_jdbc_driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ case-sensitive. The supported parameters are:
- true
- When TLS is enabled, whether to use the system certificate store

* - retainCookies
- true
- Whether to use cookies from the initial connection in subsequent
internal connections when retrieving streams from separate endpoints.

* - retainAuth
- true
- Whether to use bearer tokens obtained from the initial connection
in subsequent internal connections used for retrieving streams
from separate endpoints.

Note that URI values must be URI-encoded if they contain characters such
as !, @, $, etc.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddlewar
*/
public static class Factory implements FlightClientMiddleware.Factory {
private final ClientHeaderHandler headerHandler;
private CredentialCallOption credentialCallOption;
private CredentialCallOption credentialCallOption = null;

/**
* Construct a factory with the given header handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withDisableCertificateVerification(config.getDisableCertificateVerification())
.withToken(config.getToken())
.withCallOptions(config.toCallOption())
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.build();
} catch (final SQLException e) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
Expand Down Expand Up @@ -425,29 +426,68 @@ public static final class Builder {
private final Set<CallOption> options = new HashSet<>();
private String host;
private int port;
private String username;
private String password;
private String trustStorePath;
private String trustStorePassword;
private String token;
private boolean useEncryption;
private boolean disableCertificateVerification;
private boolean useSystemTrustStore;
private String tlsRootCertificatesPath;
private String clientCertificatePath;
private String clientKeyPath;

@VisibleForTesting
String username;

@VisibleForTesting
String password;

@VisibleForTesting
String trustStorePath;

@VisibleForTesting
String trustStorePassword;

@VisibleForTesting
String token;

@VisibleForTesting
boolean useEncryption = true;

@VisibleForTesting
boolean disableCertificateVerification;

@VisibleForTesting
boolean useSystemTrustStore = true;

@VisibleForTesting
String tlsRootCertificatesPath;

@VisibleForTesting
String clientCertificatePath;

@VisibleForTesting
String clientKeyPath;

@VisibleForTesting
private BufferAllocator allocator;

public Builder() {
@VisibleForTesting
boolean retainCookies = true;

@VisibleForTesting
boolean retainAuth = true;

// These two middlewares are for internal use within build() and should not be exposed by builder APIs.
// Note that these middlewares may not necessarily be registered.
@VisibleForTesting
ClientIncomingAuthHeaderMiddleware.Factory authFactory
= new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());

@VisibleForTesting
ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

public Builder() {
}

/**
* Copies the builder.
*
* @param original The builder to base this copy off of.
*/
private Builder(Builder original) {
@VisibleForTesting
Builder(Builder original) {
this.middlewareFactories.addAll(original.middlewareFactories);
this.options.addAll(original.options);
this.host = original.host;
Expand All @@ -464,6 +504,14 @@ private Builder(Builder original) {
this.clientCertificatePath = original.clientCertificatePath;
this.clientKeyPath = original.clientKeyPath;
this.allocator = original.allocator;

if (original.retainCookies) {
this.cookieFactory = original.cookieFactory;
}

if (original.retainAuth) {
this.authFactory = original.authFactory;
}
}

/**
Expand Down Expand Up @@ -622,6 +670,28 @@ public Builder withBufferAllocator(final BufferAllocator allocator) {
return this;
}

/**
* Indicates if cookies should be re-used by connections spawned for getStreams() calls.
* @param retainCookies The flag indicating if cookies should be re-used.
* @return this builder instance.
*/
public Builder withRetainCookies(boolean retainCookies) {
this.retainCookies = retainCookies;
return this;
}

/**
* Indicates if bearer tokens negotiated should be re-used by connections
* spawned for getStreams() calls.
*
* @param retainAuth The flag indicating if auth tokens should be re-used.
* @return this builder instance.
*/
public Builder withRetainAuth(boolean retainAuth) {
this.retainAuth = retainAuth;
return this;
}

/**
* Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler.
*
Expand Down Expand Up @@ -675,13 +745,11 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
// Copy middlewares so that the build method doesn't change the state of the builder fields itself.
Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories);
FlightClient client = null;
boolean isUsingUserPasswordAuth = username != null && token == null;

try {
ClientIncomingAuthHeaderMiddleware.Factory authFactory = null;
// Token should take priority since some apps pass in a username/password even when a token is provided
if (username != null && token == null) {
authFactory =
new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
Expand Down Expand Up @@ -722,10 +790,17 @@ public ArrowFlightSqlClientHandler build() throws SQLException {

client = clientBuilder.build();
final ArrayList<CallOption> credentialOptions = new ArrayList<>();
if (authFactory != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
if (username != null && token == null) {
// If the authFactory has already been used for a handshake, use the existing token.
// This can occur if the authFactory is being re-used for a new connection spawned for getStream().
if (authFactory.getCredentialCallOption() != null) {
credentialOptions.add(authFactory.getCredentialCallOption());
} else {
// Otherwise do the handshake and get the token if possible.
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
}
} else if (token != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ public int threadPoolSize() {
return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse cookies from the main connection.
*/
public boolean retainCookies() {
return ArrowFlightConnectionProperty.RETAIN_COOKIES.getBoolean(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse bearer tokens created from the main connection.
*/
public boolean retainAuth() {
return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -191,7 +207,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
CLIENT_CERTIFICATE("clientCertificate", null, Type.STRING, false),
CLIENT_KEY("clientKey", null, Type.STRING, false),
THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false),
TOKEN("token", null, Type.STRING, false);
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false);

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ public void testGetBasicClientAuthenticatedShouldOpenConnection()
new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withEncryption(false)
.withUsername(userTest)
.withPassword(passTest)
.withBufferAllocator(allocator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public void testGetEncryptedClientAuthenticated() throws Exception {
new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withSystemTrustStore(false)
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTrustStorePath(trustStorePath)
Expand All @@ -153,6 +154,7 @@ public void testGetEncryptedClientWithNoCertificateOnKeyStore() throws Exception
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withTrustStorePath(noCertificateKeyStorePath)
.withTrustStorePassword(noCertificateKeyStorePassword)
.withSystemTrustStore(false)
.withBufferAllocator(allocator)
.withEncryption(true)
.build()) {
Expand All @@ -170,6 +172,7 @@ public void testGetNonAuthenticatedEncryptedClientNoAuth() throws Exception {
try (ArrowFlightSqlClientHandler client =
new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withSystemTrustStore(false)
.withTrustStorePath(trustStorePath)
.withTrustStorePassword(trustStorePass)
.withBufferAllocator(allocator)
Expand All @@ -192,6 +195,7 @@ public void testGetEncryptedClientWithKeyStoreBadPasswordAndNoAuth() throws Exce
try (ArrowFlightSqlClientHandler ignored =
new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withSystemTrustStore(false)
.withTrustStorePath(trustStorePath)
.withTrustStorePassword(keyStoreBadPassword)
.withBufferAllocator(allocator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
* and interact with it.
*/
public class FlightServerTestRule implements TestRule, AutoCloseable {
public static final String DEFAULT_USER = "flight-test-user";
public static final String DEFAULT_PASSWORD = "flight-test-password";

private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestRule.class);

private final Properties properties;
Expand Down Expand Up @@ -92,7 +95,7 @@ private FlightServerTestRule(final Properties properties,
public static FlightServerTestRule createStandardTestRule(final FlightSqlProducer producer) {
UserPasswordAuthentication authentication =
new UserPasswordAuthentication.Builder()
.user("flight-test-user", "flight-test-password")
.user(DEFAULT_USER, DEFAULT_PASSWORD)
.build();

return new Builder()
Expand Down
Loading

0 comments on commit c23d4be

Please sign in to comment.