diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 9a53f9fcafdd2..2e42cf0166b06 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -17,7 +17,7 @@ package org.apache.arrow.driver.jdbc; -import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue; +import static org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue.createNewQueue; import java.sql.ResultSet; import java.sql.ResultSetMetaData; @@ -26,7 +26,8 @@ import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; +import org.apache.arrow.driver.jdbc.utils.FlightEndpointDataQueue; import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; @@ -47,8 +48,8 @@ public final class ArrowFlightJdbcFlightStreamResultSet extends ArrowFlightJdbcVectorSchemaRootResultSet { private final ArrowFlightConnection connection; - private FlightStream currentFlightStream; - private FlightStreamQueue flightStreamQueue; + private CloseableEndpointStreamPair currentEndpointData; + private FlightEndpointDataQueue flightEndpointDataQueue; private VectorSchemaRootTransformer transformer; private VectorSchemaRoot currentVectorSchemaRoot; @@ -102,20 +103,20 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( resultSet.transformer = transformer; - resultSet.execute(flightInfo); + resultSet.populateData(flightInfo); return resultSet; } private void loadNewQueue() { - Optional.ofNullable(flightStreamQueue).ifPresent(AutoCloseables::closeNoChecked); - flightStreamQueue = createNewQueue(connection.getExecutorService()); + Optional.ofNullable(flightEndpointDataQueue).ifPresent(AutoCloseables::closeNoChecked); + flightEndpointDataQueue = createNewQueue(connection.getExecutorService()); } private void loadNewFlightStream() throws SQLException { - if (currentFlightStream != null) { - AutoCloseables.closeNoChecked(currentFlightStream); + if (currentEndpointData != null) { + AutoCloseables.closeNoChecked(currentEndpointData); } - this.currentFlightStream = getNextFlightStream(true); + this.currentEndpointData = getNextEndpointStream(true); } @Override @@ -124,24 +125,24 @@ protected AvaticaResultSet execute() throws SQLException { if (flightInfo != null) { schema = flightInfo.getSchemaOptional().orElse(null); - execute(flightInfo); + populateData(flightInfo); } return this; } - private void execute(final FlightInfo flightInfo) throws SQLException { + private void populateData(final FlightInfo flightInfo) throws SQLException { loadNewQueue(); - flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); + flightEndpointDataQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); loadNewFlightStream(); // Ownership of the root will be passed onto the cursor. - if (currentFlightStream != null) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + populateDataForCurrentFlightStream(); } } - private void executeForCurrentFlightStream() throws SQLException { - final VectorSchemaRoot originalRoot = currentFlightStream.getRoot(); + private void populateDataForCurrentFlightStream() throws SQLException { + final VectorSchemaRoot originalRoot = currentEndpointData.getStream().getRoot(); if (transformer != null) { try { @@ -154,9 +155,9 @@ private void executeForCurrentFlightStream() throws SQLException { } if (schema != null) { - execute(currentVectorSchemaRoot, schema); + populateData(currentVectorSchemaRoot, schema); } else { - execute(currentVectorSchemaRoot); + populateData(currentVectorSchemaRoot); } } @@ -179,20 +180,20 @@ public boolean next() throws SQLException { return true; } - if (currentFlightStream != null) { - currentFlightStream.getRoot().clear(); - if (currentFlightStream.next()) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + currentEndpointData.getStream().getRoot().clear(); + if (currentEndpointData.getStream().next()) { + populateDataForCurrentFlightStream(); continue; } - flightStreamQueue.enqueue(currentFlightStream); + flightEndpointDataQueue.enqueue(currentEndpointData); } - currentFlightStream = getNextFlightStream(false); + currentEndpointData = getNextEndpointStream(false); - if (currentFlightStream != null) { - executeForCurrentFlightStream(); + if (currentEndpointData != null) { + populateDataForCurrentFlightStream(); continue; } @@ -207,14 +208,14 @@ public boolean next() throws SQLException { @Override protected void cancel() { super.cancel(); - final FlightStream currentFlightStream = this.currentFlightStream; - if (currentFlightStream != null) { - currentFlightStream.cancel("Cancel", null); + final CloseableEndpointStreamPair currentEndpoint = this.currentEndpointData; + if (currentEndpoint != null) { + currentEndpoint.getStream().cancel("Cancel", null); } - if (flightStreamQueue != null) { + if (flightEndpointDataQueue != null) { try { - flightStreamQueue.close(); + flightEndpointDataQueue.close(); } catch (final Exception e) { throw new RuntimeException(e); } @@ -224,13 +225,14 @@ protected void cancel() { @Override public synchronized void close() { try { - if (flightStreamQueue != null) { + if (flightEndpointDataQueue != null) { // flightStreamQueue should close currentFlightStream internally - flightStreamQueue.close(); - } else if (currentFlightStream != null) { + flightEndpointDataQueue.close(); + } else if (currentEndpointData != null) { // close is only called for currentFlightStream if there's no queue - currentFlightStream.close(); + currentEndpointData.close(); } + } catch (final Exception e) { throw new RuntimeException(e); } finally { @@ -238,13 +240,13 @@ public synchronized void close() { } } - private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException { - if (isExecution) { + private CloseableEndpointStreamPair getNextEndpointStream(final boolean canTimeout) throws SQLException { + if (canTimeout) { final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; return statementTimeout != 0 ? - flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); + flightEndpointDataQueue.next(statementTimeout, TimeUnit.SECONDS) : flightEndpointDataQueue.next(); } else { - return flightStreamQueue.next(); + return flightEndpointDataQueue.next(); } } } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java index 9e377e51decc9..20a2af6a84aa4 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -83,7 +83,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( new ArrowFlightJdbcVectorSchemaRootResultSet(null, state, signature, resultSetMetaData, timeZone, null); - resultSet.execute(vectorSchemaRoot); + resultSet.populateData(vectorSchemaRoot); return resultSet; } @@ -92,7 +92,7 @@ protected AvaticaResultSet execute() throws SQLException { throw new RuntimeException("Can only execute with execute(VectorSchemaRoot)"); } - void execute(final VectorSchemaRoot vectorSchemaRoot) { + void populateData(final VectorSchemaRoot vectorSchemaRoot) { final List fields = vectorSchemaRoot.getSchema().getFields(); final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields); signature.columns.clear(); @@ -102,7 +102,7 @@ void execute(final VectorSchemaRoot vectorSchemaRoot) { execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); } - void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { + void populateData(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields()); signature.columns.clear(); signature.columns.addAll(columns); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index bb1d524aca008..38e5a9bb362d2 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -18,14 +18,15 @@ package org.apache.arrow.driver.jdbc.client; import java.io.IOException; +import java.net.URI; import java.security.GeneralSecurityException; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; @@ -35,8 +36,8 @@ import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.LocationSchemes; import org.apache.arrow.flight.auth2.BearerCredentialWriter; import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; @@ -58,13 +59,18 @@ */ public final class ArrowFlightSqlClientHandler implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class); + private final FlightSqlClient sqlClient; private final Set options = new HashSet<>(); + private final Builder builder; ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient, - final Collection options) { - this.options.addAll(options); + final Builder builder, + final Collection credentialOptions) { + this.options.addAll(builder.options); + this.options.addAll(credentialOptions); this.sqlClient = Preconditions.checkNotNull(sqlClient); + this.builder = builder; } /** @@ -75,8 +81,9 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { * @return a new {@link ArrowFlightSqlClientHandler}. */ public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client, + final Builder builder, final Collection options) { - return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), options); + return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options); } /** @@ -95,11 +102,55 @@ private CallOption[] getOptions() { * @param flightInfo The {@link FlightInfo} instance from which to fetch results. * @return a {@code FlightStream} of results. */ - public List getStreams(final FlightInfo flightInfo) { - return flightInfo.getEndpoints().stream() - .map(FlightEndpoint::getTicket) - .map(ticket -> sqlClient.getStream(ticket, getOptions())) - .collect(Collectors.toList()); + public List getStreams(final FlightInfo flightInfo) throws SQLException { + final ArrayList endpoints = + new ArrayList<>(flightInfo.getEndpoints().size()); + + try { + for (FlightEndpoint endpoint : flightInfo.getEndpoints()) { + if (endpoint.getLocations().isEmpty()) { + // Create a stream using the current client only and do not close the client at the end. + endpoints.add(new CloseableEndpointStreamPair( + sqlClient.getStream(endpoint.getTicket(), getOptions()), null)); + } else { + // Clone the builder and then set the new endpoint on it. + // GH-38573: This code currently only tries the first Location and treats a failure as fatal. + // This should be changed to try other Locations that are available. + + // GH-38574: Currently a new FlightClient will be made for each partition that returns a non-empty Location + // then disposed of. It may be better to cache clients because a server may report the same Locations. + // It would also be good to identify when the reported location is the same as the original connection's + // Location and skip creating a FlightClient in that scenario. + final URI endpointUri = endpoint.getLocations().get(0).getUri(); + final Builder builderForEndpoint = new Builder(ArrowFlightSqlClientHandler.this.builder) + .withHost(endpointUri.getHost()) + .withPort(endpointUri.getPort()) + .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS)); + + final ArrowFlightSqlClientHandler endpointHandler = builderForEndpoint.build(); + try { + endpoints.add(new CloseableEndpointStreamPair( + endpointHandler.sqlClient.getStream(endpoint.getTicket(), + endpointHandler.getOptions()), endpointHandler.sqlClient)); + } catch (Exception ex) { + AutoCloseables.close(endpointHandler); + throw ex; + } + } + } + } catch (Exception outerException) { + try { + AutoCloseables.close(endpoints); + } catch (Exception innerEx) { + outerException.addSuppressed(innerEx); + } + + if (outerException instanceof SQLException) { + throw (SQLException) outerException; + } + throw new SQLException(outerException); + } + return endpoints; } /** @@ -364,6 +415,31 @@ public static final class Builder { private boolean useSystemTrustStore; private BufferAllocator allocator; + public Builder() { + + } + + /** + * Copies the builder. + * + * @param original The builder to base this copy off of. + */ + private Builder(Builder original) { + this.middlewareFactories.addAll(original.middlewareFactories); + this.options.addAll(original.options); + this.host = original.host; + this.port = original.port; + this.username = original.username; + this.password = original.password; + this.trustStorePath = original.trustStorePath; + this.trustStorePassword = original.trustStorePassword; + this.token = original.token; + this.useEncryption = original.useEncryption; + this.disableCertificateVerification = original.disableCertificateVerification; + this.useSystemTrustStore = original.useSystemTrustStore; + this.allocator = original.allocator; + } + /** * Sets the host for this handler. * @@ -535,18 +611,22 @@ public Builder withCallOptions(final Collection options) { * @throws SQLException on error. */ public ArrowFlightSqlClientHandler build() throws SQLException { + // Copy middlewares so that the build method doesn't change the state of the builder fields itself. + Set buildTimeMiddlewareFactories = new HashSet<>(this.middlewareFactories); FlightClient client = null; + try { ClientIncomingAuthHeaderMiddleware.Factory authFactory = null; // Token should take priority since some apps pass in a username/password even when a token is provided if (username != null && token == null) { authFactory = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); - withMiddlewareFactories(authFactory); + buildTimeMiddlewareFactories.add(authFactory); } final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); - withMiddlewareFactories(new ClientCookieMiddleware.Factory()); - middlewareFactories.forEach(clientBuilder::intercept); + + buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory()); + buildTimeMiddlewareFactories.forEach(clientBuilder::intercept); Location location; if (useEncryption) { location = Location.forGrpcTls(host, port); @@ -571,17 +651,18 @@ public ArrowFlightSqlClientHandler build() throws SQLException { } client = clientBuilder.build(); + final ArrayList credentialOptions = new ArrayList<>(); if (authFactory != null) { - options.add( + credentialOptions.add( ClientAuthenticationUtils.getAuthenticate( client, username, password, authFactory, options.toArray(new CallOption[0]))); } else if (token != null) { - options.add( + credentialOptions.add( ClientAuthenticationUtils.getAuthenticate( client, new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray( new CallOption[0]))); } - return ArrowFlightSqlClientHandler.createNewHandler(client, options); + return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions); } catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) { final SQLException originalException = new SQLException(e); diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java new file mode 100644 index 0000000000000..6c37a5b0c626c --- /dev/null +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/CloseableEndpointStreamPair.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; + +/** + * Represents a connection to a {@link org.apache.arrow.flight.FlightEndpoint}. + */ +public class CloseableEndpointStreamPair implements AutoCloseable { + + private final FlightStream stream; + private final FlightSqlClient client; + + public CloseableEndpointStreamPair(FlightStream stream, FlightSqlClient client) { + this.stream = Preconditions.checkNotNull(stream); + this.client = client; + } + + public FlightStream getStream() { + return stream; + } + + @Override + public void close() throws Exception { + AutoCloseables.close(stream, client); + } +} diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java similarity index 73% rename from java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java rename to java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java index e1d770800e40c..71cafd2ec3075 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueue.java @@ -36,6 +36,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; @@ -55,28 +56,28 @@ *
  • Repeat from (3) until next() returns null.
  • * */ -public class FlightStreamQueue implements AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(FlightStreamQueue.class); - private final CompletionService completionService; - private final Set> futures = synchronizedSet(new HashSet<>()); - private final Set allStreams = synchronizedSet(new HashSet<>()); +public class FlightEndpointDataQueue implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightEndpointDataQueue.class); + private final CompletionService completionService; + private final Set> futures = synchronizedSet(new HashSet<>()); + private final Set endpointsToClose = synchronizedSet(new HashSet<>()); private final AtomicBoolean closed = new AtomicBoolean(); /** * Instantiate a new FlightStreamQueue. */ - protected FlightStreamQueue(final CompletionService executorService) { + protected FlightEndpointDataQueue(final CompletionService executorService) { completionService = checkNotNull(executorService); } /** - * Creates a new {@link FlightStreamQueue} from the provided {@link ExecutorService}. + * Creates a new {@link FlightEndpointDataQueue} from the provided {@link ExecutorService}. * * @param service the service from which to create a new queue. * @return a new queue. */ - public static FlightStreamQueue createNewQueue(final ExecutorService service) { - return new FlightStreamQueue(new ExecutorCompletionService<>(service)); + public static FlightEndpointDataQueue createNewQueue(final ExecutorService service) { + return new FlightEndpointDataQueue(new ExecutorCompletionService<>(service)); } /** @@ -92,19 +93,21 @@ public boolean isClosed() { * Auxiliary functional interface for getting ready-to-consume FlightStreams. */ @FunctionalInterface - interface FlightStreamSupplier { - Future get() throws SQLException; + interface EndpointStreamSupplier { + Future get() throws SQLException; } - private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { + private CloseableEndpointStreamPair next(final EndpointStreamSupplier endpointStreamSupplier) throws SQLException { checkOpen(); while (!futures.isEmpty()) { - final Future future = flightStreamSupplier.get(); + final Future future = endpointStreamSupplier.get(); futures.remove(future); try { - final FlightStream stream = future.get(); - if (stream.getRoot().getRowCount() > 0) { - return stream; + final CloseableEndpointStreamPair endpoint = future.get(); + // Get the next FlightStream with content. + // The stream is non-empty. + if (endpoint.getStream().getRoot().getRowCount() > 0) { + return endpoint; } } catch (final ExecutionException | InterruptedException | CancellationException e) { throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); @@ -120,11 +123,11 @@ private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throw * @param timeoutUnit the timeoutValue time unit * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. */ - public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) + public CloseableEndpointStreamPair next(final long timeoutValue, final TimeUnit timeoutUnit) throws SQLException { return next(() -> { try { - final Future future = completionService.poll(timeoutValue, timeoutUnit); + final Future future = completionService.poll(timeoutValue, timeoutUnit); if (future != null) { return future; } @@ -142,7 +145,7 @@ public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) * * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. */ - public FlightStream next() throws SQLException { + public CloseableEndpointStreamPair next() throws SQLException { return next(() -> { try { return completionService.take(); @@ -162,21 +165,21 @@ public synchronized void checkOpen() { /** * Readily adds given {@link FlightStream}s to the queue. */ - public void enqueue(final Collection flightStreams) { - flightStreams.forEach(this::enqueue); + public void enqueue(final Collection endpointRequests) { + endpointRequests.forEach(this::enqueue); } /** * Adds given {@link FlightStream} to the queue. */ - public synchronized void enqueue(final FlightStream flightStream) { - checkNotNull(flightStream); + public synchronized void enqueue(final CloseableEndpointStreamPair endpointRequest) { + checkNotNull(endpointRequest); checkOpen(); - allStreams.add(flightStream); + endpointsToClose.add(endpointRequest); futures.add(completionService.submit(() -> { // `FlightStream#next` will block until new data can be read or stream is over. - flightStream.next(); - return flightStream; + endpointRequest.getStream().next(); + return endpointRequest; })); } @@ -187,14 +190,15 @@ private static boolean isCallStatusCancelled(final Exception e) { @Override public synchronized void close() throws SQLException { - final Set exceptions = new HashSet<>(); if (isClosed()) { return; } + + final Set exceptions = new HashSet<>(); try { - for (final FlightStream flightStream : allStreams) { + for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) { try { - flightStream.cancel("Cancelling this FlightStream.", null); + endpointToClose.getStream().cancel("Cancelling this FlightStream.", null); } catch (final Exception e) { final String errorMsg = "Failed to cancel a FlightStream."; LOGGER.error(errorMsg, e); @@ -214,9 +218,9 @@ public synchronized void close() throws SQLException { } } }); - for (final FlightStream flightStream : allStreams) { + for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) { try { - flightStream.close(); + endpointToClose.close(); } catch (final Exception e) { final String errorMsg = "Failed to close a FlightStream."; LOGGER.error(errorMsg, e); @@ -224,7 +228,7 @@ public synchronized void close() throws SQLException { } } } finally { - allStreams.clear(); + endpointsToClose.clear(); futures.clear(); closed.set(true); } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index b3002ec58416e..e2ac100b8dc36 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -19,6 +19,7 @@ import static java.lang.String.format; import static java.util.Collections.synchronizedSet; +import static org.apache.arrow.flight.Location.forGrpcInsecure; import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.anyOf; import static org.hamcrest.CoreMatchers.containsString; @@ -29,16 +30,32 @@ import static org.junit.Assert.fail; import java.sql.Connection; +import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLTimeoutException; import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.apache.arrow.driver.jdbc.utils.PartitionedFlightSqlProducer; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -144,9 +161,10 @@ public void testShouldRunSelectQuerySettingMaxRowLimit() throws Exception { @Test(expected = SQLException.class) public void testShouldThrowExceptionUponAttemptingToExecuteAnInvalidSelectQuery() throws Exception { - Statement statement = connection.createStatement(); - statement.executeQuery("SELECT * FROM SHOULD-FAIL"); - fail(); + try (Statement statement = connection.createStatement(); + ResultSet result = statement.executeQuery("SELECT * FROM SHOULD-FAIL")) { + fail(); + } } /** @@ -200,14 +218,15 @@ public void testColumnCountShouldRemainConsistentForResultSetThroughoutEntireDur */ @Test public void testShouldCloseStatementWhenIsCloseOnCompletion() throws Exception { - Statement statement = connection.createStatement(); - ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD); + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) { - statement.closeOnCompletion(); + statement.closeOnCompletion(); - resultSetNextUntilDone(resultSet); + resultSetNextUntilDone(resultSet); - collector.checkThat(statement.isClosed(), is(true)); + collector.checkThat(statement.isClosed(), is(true)); + } } /** @@ -368,9 +387,72 @@ public void testFlightStreamsQueryShouldNotTimeout() throws SQLException { final int timeoutValue = 5; try (Statement statement = connection.createStatement()) { statement.setQueryTimeout(timeoutValue); - ResultSet resultSet = statement.executeQuery(query); - CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); - resultSet.close(); + try (ResultSet resultSet = statement.executeQuery(query)) { + CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); + } + } + } + + @Test + public void testPartitionedFlightServer() throws Exception { + // Arrange + final Schema schema = new Schema( + Arrays.asList(Field.nullablePrimitive("int_column", new ArrowType.Int(32, true)))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot firstPartition = VectorSchemaRoot.create(schema, allocator); + VectorSchemaRoot secondPartition = VectorSchemaRoot.create(schema, allocator)) { + firstPartition.setRowCount(1); + ((IntVector) firstPartition.getVector(0)).set(0, 1); + secondPartition.setRowCount(1); + ((IntVector) secondPartition.getVector(0)).set(0, 2); + + // Construct the data-only nodes first. + FlightProducer firstProducer = new PartitionedFlightSqlProducer.DataOnlyFlightSqlProducer( + new Ticket("first".getBytes()), firstPartition); + FlightProducer secondProducer = new PartitionedFlightSqlProducer.DataOnlyFlightSqlProducer( + new Ticket("second".getBytes()), secondPartition); + + final FlightServer.Builder firstBuilder = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), firstProducer); + + final FlightServer.Builder secondBuilder = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), secondProducer); + + // Run the data-only nodes so that we can get the Locations they are running at. + try (FlightServer firstServer = firstBuilder.build(); + FlightServer secondServer = secondBuilder.build()) { + firstServer.start(); + secondServer.start(); + final FlightEndpoint firstEndpoint = + new FlightEndpoint(new Ticket("first".getBytes()), firstServer.getLocation()); + + final FlightEndpoint secondEndpoint = + new FlightEndpoint(new Ticket("second".getBytes()), secondServer.getLocation()); + + // Finally start the root node. + try (final PartitionedFlightSqlProducer rootProducer = new PartitionedFlightSqlProducer( + schema, firstEndpoint, secondEndpoint); + FlightServer rootServer = FlightServer.builder( + allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = DriverManager.getConnection(String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort())); + Statement newStatement = newConnection.createStatement(); + // Act + ResultSet result = newStatement.executeQuery("Select partitioned_data")) { + List resultData = new ArrayList<>(); + while (result.next()) { + resultData.add(result.getInt(1)); + } + + // Assert + assertEquals(firstPartition.getRowCount() + secondPartition.getRowCount(), resultData.size()); + assertTrue(resultData.contains(((IntVector) firstPartition.getVector(0)).get(0))); + assertTrue(resultData.contains(((IntVector) secondPartition.getVector(0)).get(0))); + } + } } } } diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java similarity index 85% rename from java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java rename to java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java index b474da55a7f1f..05325faa18ef3 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightEndpointDataQueueTest.java @@ -23,7 +23,7 @@ import java.util.concurrent.CompletionService; -import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -33,20 +33,20 @@ import org.mockito.junit.MockitoJUnitRunner; /** - * Tests for {@link FlightStreamQueue}. + * Tests for {@link FlightEndpointDataQueue}. */ @RunWith(MockitoJUnitRunner.class) -public class FlightStreamQueueTest { +public class FlightEndpointDataQueueTest { @Rule public final ErrorCollector collector = new ErrorCollector(); @Mock - private CompletionService mockedService; - private FlightStreamQueue queue; + private CompletionService mockedService; + private FlightEndpointDataQueue queue; @Before public void setUp() { - queue = new FlightStreamQueue(mockedService); + queue = new FlightEndpointDataQueue(mockedService); } @Test @@ -64,7 +64,7 @@ public void testNextShouldThrowExceptionUponClose() throws Exception { public void testEnqueueShouldThrowExceptionUponClose() throws Exception { queue.close(); ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, - () -> queue.enqueue(mock(FlightStream.class))); + () -> queue.enqueue(mock(CloseableEndpointStreamPair.class))); } @Test diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/PartitionedFlightSqlProducer.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/PartitionedFlightSqlProducer.java new file mode 100644 index 0000000000000..3230ce626fac6 --- /dev/null +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/PartitionedFlightSqlProducer.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static com.google.protobuf.Any.pack; + +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.BasicFlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +public class PartitionedFlightSqlProducer extends BasicFlightSqlProducer { + + /** + * A minimal FlightProducer intended to just serve data when given the correct Ticket. + */ + public static class DataOnlyFlightSqlProducer extends NoOpFlightProducer { + private final Ticket ticket; + private final VectorSchemaRoot data; + + public DataOnlyFlightSqlProducer(Ticket ticket, VectorSchemaRoot data) { + this.ticket = ticket; + this.data = data; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + if (!Arrays.equals(ticket.getBytes(), this.ticket.getBytes())) { + listener.error(CallStatus.INVALID_ARGUMENT.withDescription("Illegal ticket.").toRuntimeException()); + return; + } + + listener.start(data); + listener.putNext(); + listener.completed(); + } + } + + private final List endpoints; + + private final Schema schema; + + public PartitionedFlightSqlProducer(Schema schema, FlightEndpoint... endpoints) { + this.schema = schema; + this.endpoints = Arrays.asList(endpoints); + } + + @Override + protected List determineEndpoints( + T request, FlightDescriptor flightDescriptor, Schema schema) { + return endpoints; + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + final FlightSql.ActionCreatePreparedStatementResult.Builder resultBuilder = + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.EMPTY); + + final ByteString datasetSchemaBytes = ByteString.copyFrom(schema.serializeAsMessage()); + + resultBuilder.setDatasetSchema(datasetSchemaBytes); + listener.onNext(new Result(pack(resultBuilder.build()).toByteArray())); + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement( + FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { + return FlightInfo.builder(schema, descriptor, endpoints).build(); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return FlightInfo.builder(schema, descriptor, endpoints).build(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + listener.onCompleted(); + } + + // Note -- getStream() is intentionally not implemented. +}