diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java index 0ef55091ff..f7028cb55f 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlClientWithCallOptions.java @@ -284,6 +284,6 @@ private CallOption[] combine(CallOption... options) { final CallOption[] result = new CallOption[connectionOptions.length + options.length]; System.arraycopy(connectionOptions, 0, result, 0, connectionOptions.length); System.arraycopy(options, 0, result, connectionOptions.length, options.length); - return options; + return result; } } diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java index 0bab9c9e72..c079060ec8 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java @@ -84,8 +84,8 @@ public class FlightSqlConnection implements AdbcConnection { this.counter = new AtomicInteger(0); this.quirks = quirks; this.parameters = parameters; - this.client = - new FlightSqlClientWithCallOptions(new FlightSqlClient(createInitialConnection(location))); + FlightSqlClient flightSqlClient = new FlightSqlClient(createInitialConnection(location)); + this.client = new FlightSqlClientWithCallOptions(flightSqlClient, callOptions); this.clientCache = Caffeine.newBuilder() .expireAfterAccess(5, TimeUnit.MINUTES) diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java new file mode 100644 index 0000000000..19c66f8116 --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class HeaderTest { + + private FlightServer.Builder builder; + private FlightServer server; + private Map params; + private AdbcConnection connection; + private BufferAllocator allocator; + private HeaderValidator.Factory headerValidatorFactory; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + headerValidatorFactory = new HeaderValidator.Factory(); + builder = + FlightServer.builder() + .middleware(HeaderValidator.KEY, headerValidatorFactory) + .location(Location.forGrpcInsecure("localhost", 0)) + .producer(new MockFlightSqlProducer()); + params = new HashMap<>(); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection, server, allocator); + connection = null; + server = null; + allocator = null; + } + + @Test + public void testArbitraryHeader() throws Exception { + final String dummyValue = "dummy"; + final String dummyHeaderName = "test-header"; + params.put(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX + dummyHeaderName, dummyValue); + server = builder.build(); + server.start(); + connect(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals(dummyValue, headers.get(dummyHeaderName)); + } + + @Test + public void testCookies() throws Exception { + builder.middleware(CookieMiddleware.KEY, new CookieMiddleware.Factory()); + server = builder.build(); + server.start(); + + params.put(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.getKey(), true); + connect(); + try (ArrowReader reader = connection.getInfo(new int[] {})) { + + } catch (Exception ex) { + // Swallow exceptions from the RPC call. Only interested in tracking metadata. + } + CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1); + assertTrue(secondHeaders.containsKey("cookie")); + } + + @Test + public void testBearerToken() throws Exception { + builder.headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> () -> username))); + server = builder.build(); + server.start(); + + params.put(AdbcDriver.PARAM_USERNAME.getKey(), "dummy_user"); + params.put(AdbcDriver.PARAM_PASSWORD.getKey(), "dummy_password"); + connect(); + try (ArrowReader reader = connection.getInfo(new int[] {})) { + + } catch (Exception ex) { + // Swallow exceptions from the RPC call. Only interested in tracking metadata. + } + CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1); + assertTrue(secondHeaders.get("authorization").contains("Bearer")); + } + + @Test + public void testUnauthenticated() throws Exception { + builder.headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> () -> username))); + server = builder.build(); + server.start(); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.UNAUTHENTICATED, adbcException.getStatus()); + } + + static class CookieMiddleware implements FlightServerMiddleware { + + public static final Key KEY = Key.of("CookieMiddleware"); + + @Override + public void onBeforeSendingHeaders(CallHeaders callHeaders) { + callHeaders.insert("set-cookie", "test=test_val"); + } + + @Override + public void onCallCompleted(CallStatus callStatus) {} + + @Override + public void onCallErrored(Throwable throwable) {} + + public static class Factory implements FlightServerMiddleware.Factory { + + @Override + public CookieMiddleware onCallStarted( + CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) { + return new CookieMiddleware(); + } + } + } + + private void connect() throws Exception { + int port = server.getPort(); + String uri = String.format("grpc+tcp://%s:%d", "localhost", port); + params.put(AdbcDriver.PARAM_URI.getKey(), uri); + AdbcDatabase db = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + connection = db.connect(); + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java index 2131a56db6..f543dd9958 100644 --- a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/HeaderValidator.java @@ -16,13 +16,17 @@ */ package org.apache.arrow.adbc.driver.flightsql; +import java.util.ArrayList; import org.apache.arrow.flight.CallHeaders; import org.apache.arrow.flight.CallInfo; import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightCallHeaders; import org.apache.arrow.flight.FlightServerMiddleware; import org.apache.arrow.flight.RequestContext; public class HeaderValidator implements FlightServerMiddleware { + public static final Key KEY = Key.of("HeaderValidator"); + @Override public void onBeforeSendingHeaders(CallHeaders callHeaders) {} @@ -34,10 +38,28 @@ public void onCallErrored(Throwable throwable) {} public static class Factory implements FlightServerMiddleware.Factory { + private final ArrayList headersReceived = new ArrayList<>(); + @Override public HeaderValidator onCallStarted( CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) { - return null; + CallHeaders cloneHeaders = cloneHeaders(callHeaders); + headersReceived.add(cloneHeaders); + return new HeaderValidator(); + } + + public CallHeaders getHeadersReceivedAtRequest(int request) { + return cloneHeaders(headersReceived.get(request)); + } + + private static CallHeaders cloneHeaders(CallHeaders headers) { + FlightCallHeaders cloneHeaders = new FlightCallHeaders(); + for (String key : headers.keys()) { + for (String value : headers.getAll(key)) { + cloneHeaders.insert(key, value); + } + } + return cloneHeaders; } } }