Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
introduce dynamic channel pool
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbernstein2 committed Jan 7, 2022
1 parent 464bb13 commit 4ebbebe
Show file tree
Hide file tree
Showing 8 changed files with 625 additions and 151 deletions.
252 changes: 185 additions & 67 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import com.google.api.core.InternalApi;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -46,14 +47,12 @@
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import org.threeten.bp.Duration;

/**
Expand All @@ -68,22 +67,16 @@
*/
class ChannelPool extends ManagedChannel {
private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());

// size greater than 1 to allow multiple channel to refresh at the same time
// size not too large so refreshing channels doesn't use too many threads
private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2;
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);
private static final double JITTER_PERCENTAGE = 0.15;

private final ChannelPoolSettings settings;
private final ChannelFactory channelFactory;
private final ScheduledExecutorService executor;

private final Object entryWriteLock = new Object();
private final AtomicReference<ImmutableList<Entry>> entries = new AtomicReference<>();
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;
// if set, ChannelPool will manage the life cycle of channelRefreshExecutorService
@Nullable private final ScheduledExecutorService channelRefreshExecutorService;
private final ChannelFactory channelFactory;

private volatile ScheduledFuture<?> nextScheduledRefresh = null;

/**
* Factory method to create a non-refreshing channel pool
Expand All @@ -92,8 +85,9 @@ class ChannelPool extends ManagedChannel {
* @param channelFactory method to create the channels
* @return ChannelPool of non-refreshing channels
*/
@VisibleForTesting
static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IOException {
return new ChannelPool(channelFactory, poolSize, null);
return new ChannelPool(ChannelPoolSettings.staticallySized(poolSize), channelFactory, null);
}

/**
Expand All @@ -103,58 +97,66 @@ static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IO
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will
* be managed by ChannelPool
* @param executor used to schedule maintenance tasks like refresh channels and resizing the pool.
* @return ChannelPool of refreshing channels
*/
@VisibleForTesting
static ChannelPool createRefreshing(
int poolSize,
ChannelFactory channelFactory,
ScheduledExecutorService channelRefreshExecutorService)
int poolSize, ChannelFactory channelFactory, ScheduledExecutorService executor)
throws IOException {
return new ChannelPool(channelFactory, poolSize, channelRefreshExecutorService);
return new ChannelPool(
ChannelPoolSettings.staticallySized(poolSize)
.toBuilder()
.setPreemptiveReconnectEnabled(true)
.build(),
channelFactory,
executor);
}

/**
* Factory method to create a refreshing channel pool
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @return ChannelPool of refreshing channels
*/
static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory)
static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFactory)
throws IOException {
return createRefreshing(
poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE));
return new ChannelPool(settings, channelFactory, Executors.newSingleThreadScheduledExecutor());
}

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param settings options for controling the ChannelPool sizing behavior
* @param channelFactory method to create the channels
* @param poolSize number of channels in the pool
* @param channelRefreshExecutorService periodically refreshes the channels
* @param executor periodically refreshes the channels. Must be single threaded
*/
private ChannelPool(
@InternalApi("VisibleForTesting")
ChannelPool(
ChannelPoolSettings settings,
ChannelFactory channelFactory,
int poolSize,
@Nullable ScheduledExecutorService channelRefreshExecutorService)
ScheduledExecutorService executor)
throws IOException {
this.settings = settings;
this.channelFactory = channelFactory;

ImmutableList.Builder<Entry> initialListBuilder = ImmutableList.builder();

for (int i = 0; i < poolSize; i++) {
for (int i = 0; i < settings.getInitialChannelCount(); i++) {
initialListBuilder.add(new Entry(channelFactory.createSingleChannel()));
}

entries.set(initialListBuilder.build());
authority = entries.get().get(0).channel.authority();
this.channelRefreshExecutorService = channelRefreshExecutorService;

if (channelRefreshExecutorService != null) {
nextScheduledRefresh = scheduleNextRefresh();
this.executor = executor;

if (!settings.isStaticSize()) {
executor.scheduleAtFixedRate(
this::resizeSafely,
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
TimeUnit.SECONDS);
}
if (settings.isPreemptiveReconnectEnabled()) {
executor.scheduleAtFixedRate(
this::refreshSafely,
REFRESH_PERIOD.getSeconds(),
REFRESH_PERIOD.getSeconds(),
TimeUnit.SECONDS);
}
}

Expand Down Expand Up @@ -187,12 +189,9 @@ public ManagedChannel shutdown() {
for (Entry entry : localEntries) {
entry.channel.shutdown();
}
if (nextScheduledRefresh != null) {
nextScheduledRefresh.cancel(true);
}
if (channelRefreshExecutorService != null) {
if (executor != null) {
// shutdownNow will cancel scheduled tasks
channelRefreshExecutorService.shutdownNow();
executor.shutdownNow();
}
return this;
}
Expand All @@ -206,7 +205,7 @@ public boolean isShutdown() {
return false;
}
}
return channelRefreshExecutorService == null || channelRefreshExecutorService.isShutdown();
return executor == null || executor.isShutdown();
}

/** {@inheritDoc} */
Expand All @@ -218,7 +217,8 @@ public boolean isTerminated() {
return false;
}
}
return channelRefreshExecutorService == null || channelRefreshExecutorService.isTerminated();

return executor == null || executor.isTerminated();
}

/** {@inheritDoc} */
Expand All @@ -228,11 +228,8 @@ public ManagedChannel shutdownNow() {
for (Entry entry : localEntries) {
entry.channel.shutdownNow();
}
if (nextScheduledRefresh != null) {
nextScheduledRefresh.cancel(true);
}
if (channelRefreshExecutorService != null) {
channelRefreshExecutorService.shutdownNow();
if (executor != null) {
executor.shutdownNow();
}
return this;
}
Expand All @@ -249,25 +246,129 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
entry.channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
if (channelRefreshExecutorService != null) {
if (executor != null) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
executor.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}

/** Scheduling loop. */
private ScheduledFuture<?> scheduleNextRefresh() {
long delayPeriod = REFRESH_PERIOD.toMillis();
long jitter = (long) ((Math.random() - 0.5) * JITTER_PERCENTAGE * delayPeriod);
long delay = jitter + delayPeriod;
return channelRefreshExecutorService.schedule(
() -> {
scheduleNextRefresh();
refresh();
},
delay,
TimeUnit.MILLISECONDS);
void resizeSafely() {
try {
synchronized (entryWriteLock) {
resize();
}
} catch (Exception e) {
LOG.log(Level.WARNING, "Failed to resize channel pool", e);
}
}

/**
* Resize the number of channels based on the number of outstanding RPCs.
*
* <p>This method is expected to be called on a fixed interval. On every invocation it will:
*
* <ul>
* <li>Get the maximum number of outstanding RPCs since last invocation
* <li>Determine a valid range of number of channels to handle that many outstanding RPCs
* <li>If the current number of channel falls outside of that range, add or remove at most
* {@link ChannelPoolSettings#MAX_RESIZE_DELTA} to get closer to middle of that range.
* </ul>
*
* <p>Not threadsafe, must be called under the entryWriteLock monitor
*/
void resize() {
List<Entry> localEntries = entries.get();
// Estimate the peak of RPCs in the last interval by summing the peak of RPCs per channel
int actualOutstandingRpcs =
localEntries.stream().mapToInt(Entry::getAndResetMaxOutstanding).sum();

// Number of channels if each channel operated at max capacity
int minChannels =
(int) Math.ceil(actualOutstandingRpcs / (double) settings.getMaxRpcsPerChannel());
// Limit the threshold to absolute range
if (minChannels < settings.getMinChannelCount()) {
minChannels = settings.getMinChannelCount();
}

// Number of channels if each channel operated at minimum capacity
int maxChannels =
(int) Math.ceil(actualOutstandingRpcs / (double) settings.getMinRpcsPerChannel());
// Limit the threshold to absolute range
if (maxChannels > settings.getMaxChannelCount()) {
maxChannels = settings.getMaxChannelCount();
}
if (maxChannels < minChannels) {
maxChannels = minChannels;
}

// If the pool were to be resized, try to aim for the middle of the bound, but limit rate of
// change.
int tentativeTarget = (maxChannels + minChannels) / 2;
int currentSize = localEntries.size();
int delta = tentativeTarget - currentSize;
int dampenedTarget = tentativeTarget;
if (Math.abs(delta) > ChannelPoolSettings.MAX_RESIZE_DELTA) {
dampenedTarget =
currentSize + (int) Math.copySign(ChannelPoolSettings.MAX_RESIZE_DELTA, delta);
}

// Only resize the pool when thresholds are crossed
if (localEntries.size() < minChannels) {
LOG.fine(
String.format(
"Detected throughput peak of %d, expanding channel pool size: %d -> %d.",
actualOutstandingRpcs, currentSize, dampenedTarget));

expand(tentativeTarget);
} else if (localEntries.size() > maxChannels) {
LOG.fine(
String.format(
"Detected throughput drop to %d, shrinking channel pool size: %d -> %d.",
actualOutstandingRpcs, currentSize, dampenedTarget));

shrink(tentativeTarget);
}
}

/** Not threadsafe, must be called under the entryWriteLock monitor */
private void shrink(int desiredSize) {
List<Entry> localEntries = entries.get();
Preconditions.checkState(
localEntries.size() >= desiredSize, "desired size is already smaller than the current");

// Set the new list
entries.set(ImmutableList.copyOf(localEntries.subList(0, desiredSize)));
// clean up removed entries
List<Entry> removed = localEntries.subList(desiredSize, localEntries.size());
removed.forEach(Entry::requestShutdown);
}

/** Not threadsafe, must be called under the entryWriteLock monitor */
private void expand(int desiredSize) {
List<Entry> localEntries = entries.get();
Preconditions.checkState(
localEntries.size() <= desiredSize, "desired size is already bigger than the current");

ImmutableList.Builder<Entry> newEntries = ImmutableList.<Entry>builder().addAll(localEntries);

for (int i = 0; i < desiredSize - localEntries.size(); i++) {
try {
newEntries.add(new Entry(channelFactory.createSingleChannel()));
} catch (IOException e) {
LOG.log(Level.WARNING, "Failed to add channel", e);
}
}

entries.set(newEntries.build());
}

private void refreshSafely() {
try {
refresh();
} catch (Exception e) {
LOG.log(Level.WARNING, "Failed to pre-emptively refresh channnels", e);
}
}

/**
Expand Down Expand Up @@ -340,14 +441,21 @@ Entry getRetainedEntry(int affinity) {
private Entry getEntry(int affinity) {
List<Entry> localEntries = entries.get();

int index = Math.abs(affinity % localEntries.size());
int index = affinity % localEntries.size();
index = Math.abs(index);
// If index is the most negative int, abs(index) is still negative.
if (index < 0) {
index = 0;
}

return localEntries.get(index);
}

/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
private static class Entry {
private final ManagedChannel channel;
private final AtomicInteger outstandingRpcs = new AtomicInteger(0);
private final AtomicInteger maxOutstanding = new AtomicInteger();

// Flag that the channel should be closed once all of the outstanding RPC complete.
private final AtomicBoolean shutdownRequested = new AtomicBoolean();
Expand All @@ -358,6 +466,10 @@ private Entry(ManagedChannel channel) {
this.channel = channel;
}

int getAndResetMaxOutstanding() {
return maxOutstanding.getAndSet(outstandingRpcs.get());
}

/**
* Try to increment the outstanding RPC count. The method will return false if the channel is
* closing and the caller should pick a different channel. If the method returned true, the
Expand All @@ -366,7 +478,13 @@ private Entry(ManagedChannel channel) {
*/
private boolean retain() {
// register desire to start RPC
outstandingRpcs.incrementAndGet();
int currentOutstanding = outstandingRpcs.incrementAndGet();

// Rough book keeping
int prevMax = maxOutstanding.get();
if (currentOutstanding > prevMax) {
maxOutstanding.incrementAndGet();
}

// abort if the channel is closing
if (shutdownRequested.get()) {
Expand Down
Loading

0 comments on commit 4ebbebe

Please sign in to comment.