Skip to content

Commit

Permalink
Allow supplying a Rediscovery implementation (#1350)
Browse files Browse the repository at this point in the history
* Allow supplying a Rediscovery implementation

This is only for internal purposes.

* Add rediscoverySupplier tests to DriverFactory

* Add null protection
  • Loading branch information
injectives authored Dec 13, 2022
1 parent 7519388 commit 5417654
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 70 deletions.
62 changes: 43 additions & 19 deletions driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.URI;
import java.util.Objects;
import java.util.function.Supplier;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Config;
Expand All @@ -39,11 +41,13 @@
import org.neo4j.driver.internal.async.connection.ChannelConnectorImpl;
import org.neo4j.driver.internal.async.pool.ConnectionPoolImpl;
import org.neo4j.driver.internal.async.pool.PoolSettings;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cluster.RoutingProcedureClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.loadbalancing.LeastConnectedLoadBalancingStrategy;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancingStrategy;
import org.neo4j.driver.internal.logging.NettyLogging;
import org.neo4j.driver.internal.metrics.DevNullMetricsProvider;
import org.neo4j.driver.internal.metrics.InternalMetricsProvider;
Expand All @@ -70,7 +74,7 @@ public final Driver newInstance(
RetrySettings retrySettings,
Config config,
SecurityPlan securityPlan) {
return newInstance(uri, authToken, routingSettings, retrySettings, config, null, securityPlan);
return newInstance(uri, authToken, routingSettings, retrySettings, config, null, securityPlan, null);
}

public final Driver newInstance(
Expand All @@ -80,7 +84,8 @@ public final Driver newInstance(
RetrySettings retrySettings,
Config config,
EventLoopGroup eventLoopGroup,
SecurityPlan securityPlan) {
SecurityPlan securityPlan,
Supplier<Rediscovery> rediscoverySupplier) {
Bootstrap bootstrap;
boolean ownsEventLoopGroup;
if (eventLoopGroup == null) {
Expand Down Expand Up @@ -119,6 +124,7 @@ public final Driver newInstance(
newRoutingSettings,
retryLogic,
metricsProvider,
rediscoverySupplier,
config);
}

Expand Down Expand Up @@ -185,6 +191,7 @@ private InternalDriver createDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
try {
String scheme = uri.getScheme().toLowerCase();
Expand All @@ -198,6 +205,7 @@ private InternalDriver createDriver(
routingSettings,
retryLogic,
metricsProvider,
rediscoverySupplier,
config);
} else {
assertNoRoutingContext(uri, routingSettings);
Expand Down Expand Up @@ -243,9 +251,10 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
ConnectionProvider connectionProvider =
createLoadBalancer(address, connectionPool, eventExecutorGroup, config, routingSettings);
ConnectionProvider connectionProvider = createLoadBalancer(
address, connectionPool, eventExecutorGroup, config, routingSettings, rediscoverySupplier);
SessionFactory sessionFactory = createSessionFactory(connectionProvider, retryLogic, config);
InternalDriver driver = createDriver(securityPlan, sessionFactory, metricsProvider, config);
Logger log = config.logging().getLog(getClass());
Expand Down Expand Up @@ -273,24 +282,41 @@ protected LoadBalancer createLoadBalancer(
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Config config,
RoutingSettings routingSettings) {
LoadBalancingStrategy loadBalancingStrategy =
new LeastConnectedLoadBalancingStrategy(connectionPool, config.logging());
ServerAddressResolver resolver = createResolver(config);
LoadBalancer loadBalancer = new LoadBalancer(
address,
routingSettings,
RoutingSettings routingSettings,
Supplier<Rediscovery> rediscoverySupplier) {
var loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(connectionPool, config.logging());
var resolver = createResolver(config);
var domainNameResolver = Objects.requireNonNull(getDomainNameResolver(), "domainNameResolver must not be null");
var clock = createClock();
var logging = config.logging();
if (rediscoverySupplier == null) {
rediscoverySupplier =
() -> createRediscovery(address, resolver, routingSettings, clock, logging, domainNameResolver);
}
var loadBalancer = new LoadBalancer(
connectionPool,
eventExecutorGroup,
createClock(),
config.logging(),
rediscoverySupplier.get(),
routingSettings,
loadBalancingStrategy,
resolver,
getDomainNameResolver());
eventExecutorGroup,
clock,
logging);
handleNewLoadBalancer(loadBalancer);
return loadBalancer;
}

protected Rediscovery createRediscovery(
BoltServerAddress initialRouter,
ServerAddressResolver resolver,
RoutingSettings settings,
Clock clock,
Logging logging,
DomainNameResolver domainNameResolver) {
var clusterCompositionProvider =
new RoutingProcedureClusterCompositionProvider(clock, settings.routingContext());
return new RediscoveryImpl(initialRouter, clusterCompositionProvider, resolver, logging, domainNameResolver);
}

/**
* Handles new {@link LoadBalancer} instance.
* <p>
Expand All @@ -307,8 +333,6 @@ private static ServerAddressResolver createResolver(Config config) {

/**
* Creates new {@link Clock}.
* <p>
* <b>This method is protected only for testing</b>
*/
protected Clock createClock() {
return Clock.SYSTEM;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.neo4j.driver.internal.cluster;

import static java.util.Objects.requireNonNull;
import static org.neo4j.driver.internal.async.ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER;

import java.util.HashMap;
Expand Down Expand Up @@ -72,6 +73,7 @@ public RoutingTableRegistryImpl(
ConnectionPool connectionPool,
Rediscovery rediscovery,
Logging logging) {
requireNonNull(rediscovery, "rediscovery must not be null");
this.factory = factory;
this.routingTableHandlers = routingTableHandlers;
this.principalToDatabaseNameStage = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.SessionExpiredException;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.DomainNameResolver;
import org.neo4j.driver.internal.async.ConnectionContext;
import org.neo4j.driver.internal.async.connection.RoutingConnection;
import org.neo4j.driver.internal.cluster.ClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingProcedureClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.RoutingTable;
import org.neo4j.driver.internal.cluster.RoutingTableRegistry;
Expand All @@ -56,7 +52,6 @@
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.net.ServerAddressResolver;

public class LoadBalancer implements ConnectionProvider {
private static final String CONNECTION_ACQUISITION_COMPLETION_FAILURE_MESSAGE =
Expand All @@ -73,27 +68,6 @@ public class LoadBalancer implements ConnectionProvider {
private final Rediscovery rediscovery;

public LoadBalancer(
BoltServerAddress initialRouter,
RoutingSettings settings,
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Clock clock,
Logging logging,
LoadBalancingStrategy loadBalancingStrategy,
ServerAddressResolver resolver,
DomainNameResolver domainNameResolver) {
this(
connectionPool,
createRediscovery(
initialRouter, resolver, settings, clock, logging, requireNonNull(domainNameResolver)),
settings,
loadBalancingStrategy,
eventExecutorGroup,
clock,
logging);
}

private LoadBalancer(
ConnectionPool connectionPool,
Rediscovery rediscovery,
RoutingSettings settings,
Expand All @@ -117,6 +91,7 @@ private LoadBalancer(
LoadBalancingStrategy loadBalancingStrategy,
EventExecutorGroup eventExecutorGroup,
Logging logging) {
requireNonNull(rediscovery, "rediscovery must not be null");
this.connectionPool = connectionPool;
this.routingTables = routingTables;
this.rediscovery = rediscovery;
Expand Down Expand Up @@ -281,19 +256,14 @@ private static RoutingTableRegistry createRoutingTables(
connectionPool, rediscovery, clock, logging, settings.routingTablePurgeDelayMs());
}

private static Rediscovery createRediscovery(
BoltServerAddress initialRouter,
ServerAddressResolver resolver,
RoutingSettings settings,
Clock clock,
Logging logging,
DomainNameResolver domainNameResolver) {
ClusterCompositionProvider clusterCompositionProvider =
new RoutingProcedureClusterCompositionProvider(clock, settings.routingContext());
return new RediscoveryImpl(initialRouter, clusterCompositionProvider, resolver, logging, domainNameResolver);
}

private static RuntimeException unknownMode(AccessMode mode) {
return new IllegalArgumentException("Mode '" + mode + "' is not supported");
}

/**
* <b>This method is only for testing</b>
*/
public Rediscovery getRediscovery() {
return rediscovery;
}
}
3 changes: 3 additions & 0 deletions driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import java.net.URI;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.DriverFactory;
import org.neo4j.driver.internal.InternalDriver;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.metrics.MetricsProvider;
import org.neo4j.driver.internal.retry.RetryLogic;
Expand Down Expand Up @@ -147,6 +149,7 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
return driverIterator.next();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ private Driver createDriver(EventLoopGroup eventLoopGroup) {
RetrySettings.DEFAULT,
Config.defaultConfig(),
eventLoopGroup,
SecurityPlanImpl.insecure());
SecurityPlanImpl.insecure(),
null);
}

private void testConnection(Driver driver) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void testCustomSecurityPlanUsed() {
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan);
securityPlan,
null);

assertFalse(driverFactory.capturedSecurityPlans.isEmpty());
assertTrue(driverFactory.capturedSecurityPlans.stream().allMatch(capturePlan -> capturePlan == securityPlan));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.junit.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand All @@ -41,6 +45,7 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.util.concurrent.EventExecutorGroup;
import java.net.URI;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -55,6 +60,8 @@
import org.neo4j.driver.internal.async.LeakLoggingNetworkSession;
import org.neo4j.driver.internal.async.NetworkSession;
import org.neo4j.driver.internal.async.connection.BootstrapFactory;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer;
Expand Down Expand Up @@ -191,6 +198,61 @@ void shouldCreateAppropriateDriverType(String uri) {
}
}

@Test
void shouldUseBuiltInRediscoveryByDefault() {
// GIVEN
var driverFactory = new DriverFactory();
var securityPlan =
new SecuritySettings.SecuritySettingsBuilder().build().createSecurityPlan("neo4j");

// WHEN
var driver = driverFactory.newInstance(
URI.create("neo4j://localhost:7687"),
AuthTokens.none(),
RoutingSettings.DEFAULT,
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan,
null);

// THEN
var sessionFactory = ((InternalDriver) driver).getSessionFactory();
var connectionProvider = ((SessionFactoryImpl) sessionFactory).getConnectionProvider();
var rediscovery = ((LoadBalancer) connectionProvider).getRediscovery();
assertTrue(rediscovery instanceof RediscoveryImpl);
}

@Test
void shouldUseSuppliedRediscovery() {
// GIVEN
var driverFactory = new DriverFactory();
var securityPlan =
new SecuritySettings.SecuritySettingsBuilder().build().createSecurityPlan("neo4j");
@SuppressWarnings("unchecked")
Supplier<Rediscovery> rediscoverySupplier = mock(Supplier.class);
var rediscovery = mock(Rediscovery.class);
given(rediscoverySupplier.get()).willReturn(rediscovery);

// WHEN
var driver = driverFactory.newInstance(
URI.create("neo4j://localhost:7687"),
AuthTokens.none(),
RoutingSettings.DEFAULT,
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan,
rediscoverySupplier);

// THEN
var sessionFactory = ((InternalDriver) driver).getSessionFactory();
var connectionProvider = ((SessionFactoryImpl) sessionFactory).getConnectionProvider();
var actualRediscovery = ((LoadBalancer) connectionProvider).getRediscovery();
then(rediscoverySupplier).should().get();
assertEquals(rediscovery, actualRediscovery);
}

private Driver createDriver(String uri, DriverFactory driverFactory) {
return createDriver(uri, driverFactory, defaultConfig());
}
Expand Down Expand Up @@ -239,6 +301,7 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
throw new UnsupportedOperationException("Can't create routing driver");
}
Expand Down Expand Up @@ -276,7 +339,8 @@ protected LoadBalancer createLoadBalancer(
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Config config,
RoutingSettings routingSettings) {
RoutingSettings routingSettings,
Supplier<Rediscovery> rediscoverySupplier) {
return null;
}

Expand Down
Loading

0 comments on commit 5417654

Please sign in to comment.