Skip to content

Commit

Permalink
GH-34532: [Java] Change JDBC to handle multi-endpoints
Browse files Browse the repository at this point in the history
- Create new clients to connect to new locations in endpoints.
- If no location is reported, use the current connection.
- Change connecting to each endpoint to happen asynchronously
(concurrently) instead of sequentially, one endpoint at a time.
- Change stream clean-up to be done as soon as a stream is
finished instead of at the end of the result set.
  • Loading branch information
jduo committed Oct 30, 2023
1 parent cb11e44 commit cfec8d0
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.arrow.driver.jdbc;

import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue;
import static org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue.createNewQueue;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
Expand All @@ -26,7 +26,8 @@
import java.util.TimeZone;
import java.util.concurrent.TimeUnit;

import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue;
import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair;
import org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue;
import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
Expand All @@ -47,8 +48,8 @@ public final class ArrowFlightJdbcFlightStreamResultSet
extends ArrowFlightJdbcVectorSchemaRootResultSet {

private final ArrowFlightConnection connection;
private FlightStream currentFlightStream;
private FlightStreamQueue flightStreamQueue;
private CloseableEndpointStreamPair currentEndpointData;
private FlightEndpointDataQueue flightEndpointDataQueue;

private VectorSchemaRootTransformer transformer;
private VectorSchemaRoot currentVectorSchemaRoot;
Expand Down Expand Up @@ -102,20 +103,20 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo(

resultSet.transformer = transformer;

resultSet.execute(flightInfo);
resultSet.populateData(flightInfo);
return resultSet;
}

private void loadNewQueue() {
Optional.ofNullable(flightStreamQueue).ifPresent(AutoCloseables::closeNoChecked);
flightStreamQueue = createNewQueue(connection.getExecutorService());
Optional.ofNullable(flightEndpointDataQueue).ifPresent(AutoCloseables::closeNoChecked);
flightEndpointDataQueue = createNewQueue(connection.getExecutorService());
}

private void loadNewFlightStream() throws SQLException {
if (currentFlightStream != null) {
AutoCloseables.closeNoChecked(currentFlightStream);
if (currentEndpointData != null) {
AutoCloseables.closeNoChecked(currentEndpointData);
}
this.currentFlightStream = getNextFlightStream(true);
this.currentEndpointData = getNextEndpointStream(true);
}

@Override
Expand All @@ -124,24 +125,24 @@ protected AvaticaResultSet execute() throws SQLException {

if (flightInfo != null) {
schema = flightInfo.getSchemaOptional().orElse(null);
execute(flightInfo);
populateData(flightInfo);
}
return this;
}

private void execute(final FlightInfo flightInfo) throws SQLException {
private void populateData(final FlightInfo flightInfo) throws SQLException {
loadNewQueue();
flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo));
flightEndpointDataQueue.enqueue(connection.getClientHandler().getStreams(flightInfo));
loadNewFlightStream();

// Ownership of the root will be passed onto the cursor.
if (currentFlightStream != null) {
executeForCurrentFlightStream();
if (currentEndpointData != null) {
populateDataForCurrentFlightStream();
}
}

private void executeForCurrentFlightStream() throws SQLException {
final VectorSchemaRoot originalRoot = currentFlightStream.getRoot();
private void populateDataForCurrentFlightStream() throws SQLException {
final VectorSchemaRoot originalRoot = currentEndpointData.getStream().getRoot();

if (transformer != null) {
try {
Expand All @@ -154,9 +155,9 @@ private void executeForCurrentFlightStream() throws SQLException {
}

if (schema != null) {
execute(currentVectorSchemaRoot, schema);
populateData(currentVectorSchemaRoot, schema);
} else {
execute(currentVectorSchemaRoot);
populateData(currentVectorSchemaRoot);
}
}

Expand All @@ -179,20 +180,23 @@ public boolean next() throws SQLException {
return true;
}

if (currentFlightStream != null) {
currentFlightStream.getRoot().clear();
if (currentFlightStream.next()) {
executeForCurrentFlightStream();
if (currentEndpointData != null) {
if (currentEndpointData.getStream().next()) {
populateDataForCurrentFlightStream();
continue;
}

flightStreamQueue.enqueue(currentFlightStream);
try {
AutoCloseables.close(currentEndpointData);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}

currentFlightStream = getNextFlightStream(false);
currentEndpointData = getNextEndpointStream(false);

if (currentFlightStream != null) {
executeForCurrentFlightStream();
if (currentEndpointData != null) {
populateDataForCurrentFlightStream();
continue;
}

Expand All @@ -207,14 +211,19 @@ public boolean next() throws SQLException {
@Override
protected void cancel() {
super.cancel();
final FlightStream currentFlightStream = this.currentFlightStream;
final CloseableEndpointStreamPair currentFlightStream = this.currentEndpointData;
if (currentFlightStream != null) {
currentFlightStream.cancel("Cancel", null);
currentFlightStream.getStream().cancel("Cancel", null);
try {
currentFlightStream.close();
} catch (final Exception e) {
throw new RuntimeException(e);
}
}

if (flightStreamQueue != null) {
if (flightEndpointDataQueue != null) {
try {
flightStreamQueue.close();
flightEndpointDataQueue.close();
} catch (final Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -224,12 +233,12 @@ protected void cancel() {
@Override
public synchronized void close() {
try {
if (flightStreamQueue != null) {
if (flightEndpointDataQueue != null) {
// flightStreamQueue should close currentFlightStream internally
flightStreamQueue.close();
} else if (currentFlightStream != null) {
flightEndpointDataQueue.close();
} else if (currentEndpointData != null) {
// close is only called for currentFlightStream if there's no queue
currentFlightStream.close();
currentEndpointData.close();
}
} catch (final Exception e) {
throw new RuntimeException(e);
Expand All @@ -238,13 +247,13 @@ public synchronized void close() {
}
}

private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException {
if (isExecution) {
private CloseableEndpointStreamPair getNextEndpointStream(final boolean canTimeout) throws SQLException {
if (canTimeout) {
final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0;
return statementTimeout != 0 ?
flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next();
flightEndpointDataQueue.next(statementTimeout, TimeUnit.SECONDS) : flightEndpointDataQueue.next();
} else {
return flightStreamQueue.next();
return flightEndpointDataQueue.next();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(
new ArrowFlightJdbcVectorSchemaRootResultSet(null, state, signature, resultSetMetaData,
timeZone, null);

resultSet.execute(vectorSchemaRoot);
resultSet.populateData(vectorSchemaRoot);
return resultSet;
}

Expand All @@ -92,7 +92,7 @@ protected AvaticaResultSet execute() throws SQLException {
throw new RuntimeException("Can only execute with execute(VectorSchemaRoot)");
}

void execute(final VectorSchemaRoot vectorSchemaRoot) {
void populateData(final VectorSchemaRoot vectorSchemaRoot) {
final List<Field> fields = vectorSchemaRoot.getSchema().getFields();
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields);
signature.columns.clear();
Expand All @@ -102,7 +102,7 @@ void execute(final VectorSchemaRoot vectorSchemaRoot) {
execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns);
}

void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) {
void populateData(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) {
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields());
signature.columns.clear();
signature.columns.addAll(columns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.concurrent.Callable;

import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.flight.CallOption;
Expand All @@ -35,7 +36,6 @@
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
Expand All @@ -58,13 +58,18 @@
*/
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);

private final FlightSqlClient sqlClient;
private final Set<CallOption> options = new HashSet<>();
private final Builder builder;

ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient,
final Collection<CallOption> options) {
this.options.addAll(options);
final Builder builder,
final Collection<CallOption> credentialOptions) {
this.options.addAll(builder.options);
this.options.addAll(credentialOptions);
this.sqlClient = Preconditions.checkNotNull(sqlClient);
this.builder = builder;
}

/**
Expand All @@ -75,8 +80,9 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable {
* @return a new {@link ArrowFlightSqlClientHandler}.
*/
public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client,
final Builder builder,
final Collection<CallOption> options) {
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), options);
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options);
}

/**
Expand All @@ -95,11 +101,34 @@ private CallOption[] getOptions() {
* @param flightInfo The {@link FlightInfo} instance from which to fetch results.
* @return a {@code FlightStream} of results.
*/
public List<FlightStream> getStreams(final FlightInfo flightInfo) {
return flightInfo.getEndpoints().stream()
.map(FlightEndpoint::getTicket)
.map(ticket -> sqlClient.getStream(ticket, getOptions()))
.collect(Collectors.toList());
public List<Callable<CloseableEndpointStreamPair>> getStreams(final FlightInfo flightInfo) {
final ArrayList<Callable<CloseableEndpointStreamPair>> lazyStreams =
new ArrayList<>(flightInfo.getEndpoints().size());
for (FlightEndpoint endpoint : flightInfo.getEndpoints()) {
lazyStreams.add(() -> {
final CloseableEndpointStreamPair resultPair;
if (endpoint.getLocations().isEmpty()) {
// Create a stream using the current client only and do not close the client at the end.
resultPair = new CloseableEndpointStreamPair(
sqlClient.getStream(endpoint.getTicket(), getOptions()), null);
} else {
final ArrowFlightSqlClientHandler handler = ArrowFlightSqlClientHandler.this.builder.build();
try {
resultPair = new CloseableEndpointStreamPair(
handler.sqlClient.getStream(endpoint.getTicket(), handler.getOptions()), handler.sqlClient);
} catch (Exception ex) {
AutoCloseables.close(handler);
throw ex;
}
}
if (resultPair.getStream().next()) {
return resultPair;
}
resultPair.close();
return null;
});
}
return lazyStreams;
}

/**
Expand Down Expand Up @@ -535,18 +564,21 @@ public Builder withCallOptions(final Collection<CallOption> options) {
* @throws SQLException on error.
*/
public ArrowFlightSqlClientHandler build() throws SQLException {
// Copy middlewares so that the build method doesn't change the state of the builder fields itself.
Set<FlightClientMiddleware.Factory> buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories);
FlightClient client = null;

try {
ClientIncomingAuthHeaderMiddleware.Factory authFactory = null;
// Token should take priority since some apps pass in a username/password even when a token is provided
if (username != null && token == null) {
authFactory =
new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
withMiddlewareFactories(authFactory);
buildTimeMiddlewareFactories.add(authFactory);
}
final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
withMiddlewareFactories(new ClientCookieMiddleware.Factory());
middlewareFactories.forEach(clientBuilder::intercept);
buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
Location location;
if (useEncryption) {
location = Location.forGrpcTls(host, port);
Expand All @@ -571,17 +603,18 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
}

client = clientBuilder.build();
final ArrayList<CallOption> credentialOptions = new ArrayList<>();
if (authFactory != null) {
options.add(
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, username, password, authFactory, options.toArray(new CallOption[0])));
} else if (token != null) {
options.add(
credentialOptions.add(
ClientAuthenticationUtils.getAuthenticate(
client, new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray(
new CallOption[0])));
}
return ArrowFlightSqlClientHandler.createNewHandler(client, options);
return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions);

} catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) {
final SQLException originalException = new SQLException(e);
Expand Down
Loading

0 comments on commit cfec8d0

Please sign in to comment.