Skip to content

Commit

Permalink
Add tests and bugfixes
Browse files Browse the repository at this point in the history
Add tests for arbitrary headers, cookies, auth.
Fix bugs retaining the bearer token.
  • Loading branch information
jduo committed Feb 1, 2024
1 parent 96c9632 commit 676be98
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,6 @@ private CallOption[] combine(CallOption... options) {
final CallOption[] result = new CallOption[connectionOptions.length + options.length];
System.arraycopy(connectionOptions, 0, result, 0, connectionOptions.length);
System.arraycopy(options, 0, result, connectionOptions.length, options.length);
return options;
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public class FlightSqlConnection implements AdbcConnection {
this.counter = new AtomicInteger(0);
this.quirks = quirks;
this.parameters = parameters;
this.client =
new FlightSqlClientWithCallOptions(new FlightSqlClient(createInitialConnection(location)));
FlightSqlClient flightSqlClient = new FlightSqlClient(createInitialConnection(location));
this.client = new FlightSqlClientWithCallOptions(flightSqlClient, callOptions);
this.clientCache =
Caffeine.newBuilder()
.expireAfterAccess(5, TimeUnit.MINUTES)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.adbc.driver.flightsql;

import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDatabase;
import org.apache.arrow.adbc.core.AdbcDriver;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.drivermanager.AdbcDriverManager;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

public class HeaderTest {

private FlightServer.Builder builder;
private FlightServer server;
private Map<String, Object> params;
private AdbcConnection connection;
private BufferAllocator allocator;
private HeaderValidator.Factory headerValidatorFactory;

@BeforeEach
public void setUp() {
allocator = new RootAllocator(Long.MAX_VALUE);
headerValidatorFactory = new HeaderValidator.Factory();
builder =
FlightServer.builder()
.middleware(HeaderValidator.KEY, headerValidatorFactory)
.location(Location.forGrpcInsecure("localhost", 0))
.producer(new MockFlightSqlProducer());
params = new HashMap<>();
}

@AfterEach
public void tearDown() throws Exception {
AutoCloseables.close(connection, server, allocator);
connection = null;
server = null;
allocator = null;
}

@Test
public void testArbitraryHeader() throws Exception {
final String dummyValue = "dummy";
final String dummyHeaderName = "test-header";
params.put(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX + dummyHeaderName, dummyValue);
server = builder.build();
server.start();
connect();

CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0);
assertEquals(dummyValue, headers.get(dummyHeaderName));
}

@Test
public void testCookies() throws Exception {
builder.middleware(CookieMiddleware.KEY, new CookieMiddleware.Factory());
server = builder.build();
server.start();

params.put(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.getKey(), true);
connect();
try (ArrowReader reader = connection.getInfo(new int[] {})) {

} catch (Exception ex) {
// Swallow exceptions from the RPC call. Only interested in tracking metadata.
}
CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1);
assertTrue(secondHeaders.containsKey("cookie"));
}

@Test
public void testBearerToken() throws Exception {
builder.headerAuthenticator(
new GeneratedBearerTokenAuthenticator(
new BasicCallHeaderAuthenticator((username, password) -> () -> username)));
server = builder.build();
server.start();

params.put(AdbcDriver.PARAM_USERNAME.getKey(), "dummy_user");
params.put(AdbcDriver.PARAM_PASSWORD.getKey(), "dummy_password");
connect();
try (ArrowReader reader = connection.getInfo(new int[] {})) {

} catch (Exception ex) {
// Swallow exceptions from the RPC call. Only interested in tracking metadata.
}
CallHeaders secondHeaders = headerValidatorFactory.getHeadersReceivedAtRequest(1);
assertTrue(secondHeaders.get("authorization").contains("Bearer"));
}

@Test
public void testUnauthenticated() throws Exception {
builder.headerAuthenticator(
new GeneratedBearerTokenAuthenticator(
new BasicCallHeaderAuthenticator((username, password) -> () -> username)));
server = builder.build();
server.start();

AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
assertEquals(AdbcStatusCode.UNAUTHENTICATED, adbcException.getStatus());
}

static class CookieMiddleware implements FlightServerMiddleware {

public static final Key<CookieMiddleware> KEY = Key.of("CookieMiddleware");

@Override
public void onBeforeSendingHeaders(CallHeaders callHeaders) {
callHeaders.insert("set-cookie", "test=test_val");
}

@Override
public void onCallCompleted(CallStatus callStatus) {}

@Override
public void onCallErrored(Throwable throwable) {}

public static class Factory implements FlightServerMiddleware.Factory<CookieMiddleware> {

@Override
public CookieMiddleware onCallStarted(
CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
return new CookieMiddleware();
}
}
}

private void connect() throws Exception {
int port = server.getPort();
String uri = String.format("grpc+tcp://%s:%d", "localhost", port);
params.put(AdbcDriver.PARAM_URI.getKey(), uri);
AdbcDatabase db =
AdbcDriverManager.getInstance()
.connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
connection = db.connect();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
*/
package org.apache.arrow.adbc.driver.flightsql;

import java.util.ArrayList;
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightCallHeaders;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.RequestContext;

public class HeaderValidator implements FlightServerMiddleware {
public static final Key<HeaderValidator> KEY = Key.of("HeaderValidator");

@Override
public void onBeforeSendingHeaders(CallHeaders callHeaders) {}

Expand All @@ -34,10 +38,28 @@ public void onCallErrored(Throwable throwable) {}

public static class Factory implements FlightServerMiddleware.Factory<HeaderValidator> {

private final ArrayList<CallHeaders> headersReceived = new ArrayList<>();

@Override
public HeaderValidator onCallStarted(
CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
return null;
CallHeaders cloneHeaders = cloneHeaders(callHeaders);
headersReceived.add(cloneHeaders);
return new HeaderValidator();
}

public CallHeaders getHeadersReceivedAtRequest(int request) {
return cloneHeaders(headersReceived.get(request));
}

private static CallHeaders cloneHeaders(CallHeaders headers) {
FlightCallHeaders cloneHeaders = new FlightCallHeaders();
for (String key : headers.keys()) {
for (String value : headers.getAll(key)) {
cloneHeaders.insert(key, value);
}
}
return cloneHeaders;
}
}
}

0 comments on commit 676be98

Please sign in to comment.