Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38576: [Java] Change JDBC driver to optionally preserve cookies and auth tokens when getting streams #38580

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/java/flight_sql_jdbc_driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ case-sensitive. The supported parameters are:
- true
- When TLS is enabled, whether to use the system certificate store

* - retainCookies
- true
- Whether to use cookies from the initial connection in subsequent
internal connections when retrieving streams from separate endpoints.

* - retainAuth
- true
- Whether to use bearer tokens obtained from the initial connection
in subsequent internal connections used for retrieving streams
from separate endpoints.

Note that URI values must be URI-encoded if they contain characters such
as !, @, $, etc.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddlewar
*/
public static class Factory implements FlightClientMiddleware.Factory {
private final ClientHeaderHandler headerHandler;
private CredentialCallOption credentialCallOption;
private CredentialCallOption credentialCallOption = null;

/**
* Construct a factory with the given header handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withDisableCertificateVerification(config.getDisableCertificateVerification())
.withToken(config.getToken())
.withCallOptions(config.toCallOption())
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.build();
} catch (final SQLException e) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
Expand Down Expand Up @@ -425,29 +426,68 @@ public static final class Builder {
private final Set<CallOption> options = new HashSet<>();
private String host;
private int port;
private String username;
private String password;
private String trustStorePath;
private String trustStorePassword;
private String token;
private boolean useEncryption;
private boolean disableCertificateVerification;
private boolean useSystemTrustStore;
private String tlsRootCertificatesPath;
private String clientCertificatePath;
private String clientKeyPath;

@VisibleForTesting
String username;

@VisibleForTesting
String password;

@VisibleForTesting
String trustStorePath;

@VisibleForTesting
String trustStorePassword;

@VisibleForTesting
String token;

@VisibleForTesting
boolean useEncryption = true;

@VisibleForTesting
boolean disableCertificateVerification;

@VisibleForTesting
boolean useSystemTrustStore = true;

@VisibleForTesting
String tlsRootCertificatesPath;

@VisibleForTesting
String clientCertificatePath;

@VisibleForTesting
String clientKeyPath;

@VisibleForTesting
private BufferAllocator allocator;

public Builder() {
@VisibleForTesting
boolean retainCookies = true;

@VisibleForTesting
boolean retainAuth = true;

// These two middlewares are for internal use within build() and should not be exposed by builder APIs.
// Note that these middlewares may not necessarily be registered.
@VisibleForTesting
ClientIncomingAuthHeaderMiddleware.Factory authFactory
= new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());

@VisibleForTesting
ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

public Builder() {
}

/**
* Copies the builder.
*
* @param original The builder to base this copy off of.
*/
private Builder(Builder original) {
@VisibleForTesting
Builder(Builder original) {
this.middlewareFactories.addAll(original.middlewareFactories);
this.options.addAll(original.options);
this.host = original.host;
Expand All @@ -464,6 +504,14 @@ private Builder(Builder original) {
this.clientCertificatePath = original.clientCertificatePath;
this.clientKeyPath = original.clientKeyPath;
this.allocator = original.allocator;

if (original.retainCookies) {
this.cookieFactory = original.cookieFactory;
}

if (original.retainAuth) {
this.authFactory = original.authFactory;
}
}

/**
Expand Down Expand Up @@ -622,6 +670,28 @@ public Builder withBufferAllocator(final BufferAllocator allocator) {
return this;
}

/**
* Indicates if cookies should be re-used by connections spawned for getStreams() calls.
* @param retainCookies The flag indicating if cookies should be re-used.
* @return this builder instance.
*/
public Builder withRetainCookies(boolean retainCookies) {
this.retainCookies = retainCookies;
return this;
}

/**
* Indicates if bearer tokens negotiated should be re-used by connections
* spawned for getStreams() calls.
*
* @param retainAuth The flag indicating if auth tokens should be re-used.
* @return this builder instance.
*/
public Builder withRetainAuth(boolean retainAuth) {
this.retainAuth = retainAuth;
return this;
}

/**
* Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler.
*
Expand Down Expand Up @@ -675,13 +745,11 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
// Copy middlewares so that the build method doesn't change the state of the builder fields itself.
Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories);
FlightClient client = null;
boolean isUsingUserPasswordAuth = username != null && token == null;

try {
ClientIncomingAuthHeaderMiddleware.Factory authFactory = null;
// Token should take priority since some apps pass in a username/password even when a token is provided
if (username != null && token == null) {
authFactory =
new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
Expand Down Expand Up @@ -722,10 +790,17 @@ public ArrowFlightSqlClientHandler build() throws SQLException {

client = clientBuilder.build();
final ArrayList<CallOption> credentialOptions = new ArrayList<>();
if (authFactory != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
if (isUsingUserPasswordAuth) {
// If the authFactory has already been used for a handshake, use the existing token.
// This can occur if the authFactory is being re-used for a new connection spawned for getStream().
if (authFactory.getCredentialCallOption() != null) {
credentialOptions.add(authFactory.getCredentialCallOption());
} else {
// Otherwise do the handshake and get the token if possible.
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
}
} else if (token != null) {
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ public int threadPoolSize() {
return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse cookies from the main connection.
*/
public boolean retainCookies() {
return ArrowFlightConnectionProperty.RETAIN_COOKIES.getBoolean(properties);
}

/**
* Indicates if sub-connections created for stream retrieval
* should reuse bearer tokens created from the main connection.
*/
public boolean retainAuth() {
return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -191,7 +207,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
CLIENT_CERTIFICATE("clientCertificate", null, Type.STRING, false),
CLIENT_KEY("clientKey", null, Type.STRING, false),
THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false),
TOKEN("token", null, Type.STRING, false);
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false);

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,21 @@ public void testGetEncryptedClientWithBadMTlsCertPath() {
final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
userTest, passTest);

assertThrows(SQLException.class, () -> new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(badClientMTlsCertPath)
.withClientKey(clientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build());
assertThrows(SQLException.class, () -> {
try (ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(badClientMTlsCertPath)
.withClientKey(clientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build()) {
Assert.fail();
}
});
}

/**
Expand All @@ -162,17 +166,21 @@ public void testGetEncryptedClientWithBadMTlsKeyPath() {
final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
userTest, passTest);

assertThrows(SQLException.class, () -> new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(clientMTlsCertPath)
.withClientKey(badClientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build());
assertThrows(SQLException.class, () -> {
try (ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler.Builder()
.withHost(FLIGHT_SERVER_TEST_RULE.getHost())
.withPort(FLIGHT_SERVER_TEST_RULE.getPort())
.withUsername(credentials.getUserName())
.withPassword(credentials.getPassword())
.withTlsRootCertificates(tlsRootCertsPath)
.withClientCertificate(clientMTlsCertPath)
.withClientKey(badClientMTlsKeyPath)
.withBufferAllocator(allocator)
.withEncryption(true)
.build()) {
Assert.fail();
}
});
}

/**
Expand Down Expand Up @@ -222,7 +230,7 @@ public void testGetEncryptedConnectionWithValidCredentialsAndTlsRootsPath() thro
final ArrowFlightJdbcDataSource dataSource =
ArrowFlightJdbcDataSource.createNewDataSource(properties);
try (final Connection connection = dataSource.getConnection()) {
assert connection.isValid(300);
Assert.assertTrue(connection.isValid(300));
}
}

Expand All @@ -245,7 +253,7 @@ public void testGetNonAuthenticatedEncryptedConnection() throws Exception {

final ArrowFlightJdbcDataSource dataSource = ArrowFlightJdbcDataSource.createNewDataSource(properties);
try (final Connection connection = dataSource.getConnection()) {
assert connection.isValid(300);
Assert.assertTrue(connection.isValid(300));
}
}

Expand Down
Loading
Loading