Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

feat: dynamic channel pool scaled by number of outstanding request #1569

Merged
merged 28 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
158b06d
merge RefreshingManagedChannel & SafeShutdownManagedChannel into Chan…
igorbernstein2 Nov 19, 2021
7e02e4a
migrate SafeShutdownManagedChannel tests
igorbernstein2 Nov 19, 2021
50735bd
migrate old tests and remove RefreshingManagedChannel and SafeShutdow…
igorbernstein2 Nov 19, 2021
3a501e0
fix test
igorbernstein2 Nov 22, 2021
321eb63
Merge branch 'main' into refactor-channel-pool
chanseokoh Nov 30, 2021
2e00f63
Merge branch 'main' into refactor-channel-pool
chanseokoh Dec 7, 2021
6cb0964
address feedback
igorbernstein2 Dec 9, 2021
ec33d83
Merge remote-tracking branch 'igor/refactor-channel-pool' into refact…
igorbernstein2 Dec 9, 2021
8be5066
Merge branch 'main' into refactor-channel-pool
chanseokoh Dec 10, 2021
8f6e39e
fix race condition on refresh()
igorbernstein2 Dec 23, 2021
c93daf4
fix warnings in test
igorbernstein2 Dec 23, 2021
a386cf2
Merge remote-tracking branch 'origin/main' into refactor-channel-pool
igorbernstein2 Dec 23, 2021
9dfa4e9
Merge remote-tracking branch 'igor/refactor-channel-pool' into refact…
igorbernstein2 Dec 23, 2021
0062b0f
Update gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest…
igorbernstein2 Dec 23, 2021
cc0b507
Update gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest…
igorbernstein2 Dec 23, 2021
7bf8b6d
handle race condition
igorbernstein2 Jan 6, 2022
4858e76
Merge branch 'main' into refactor-channel-pool
igorbernstein2 Jan 6, 2022
d2d7830
Update gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
igorbernstein2 Jan 6, 2022
464bb13
Merge branch 'main' into refactor-channel-pool
chanseokoh Jan 6, 2022
4ebbebe
introduce dynamic channel pool
igorbernstein2 Nov 22, 2021
7cc62d0
Merge remote-tracking branch 'origin/main' into dynamic-channel-pool
igorbernstein2 Jan 7, 2022
5f65a37
fix test after broken merge
igorbernstein2 Jan 7, 2022
10c0cf0
format
igorbernstein2 Jan 7, 2022
71f102e
Merge branch 'main' into dynamic-channel-pool
igorbernstein2 Jan 10, 2022
3381e9e
address feedback
igorbernstein2 Feb 11, 2022
774d4ca
Merge remote-tracking branch 'origin/main' into dynamic-channel-pool
igorbernstein2 Feb 11, 2022
71d177b
remove unused import
igorbernstein2 Feb 14, 2022
d82c6b8
inline old factory methods
igorbernstein2 Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 182 additions & 67 deletions gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import com.google.api.core.InternalApi;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -46,14 +47,12 @@
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import org.threeten.bp.Duration;

/**
Expand All @@ -68,22 +67,16 @@
*/
class ChannelPool extends ManagedChannel {
private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());

// size greater than 1 to allow multiple channel to refresh at the same time
// size not too large so refreshing channels doesn't use too many threads
private static final int CHANNEL_REFRESH_EXECUTOR_SIZE = 2;
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);
private static final double JITTER_PERCENTAGE = 0.15;

private final ChannelPoolSettings settings;
private final ChannelFactory channelFactory;
private final ScheduledExecutorService executor;

private final Object entryWriteLock = new Object();
private final AtomicReference<ImmutableList<Entry>> entries = new AtomicReference<>();
@VisibleForTesting final AtomicReference<ImmutableList<Entry>> entries = new AtomicReference<>();
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;
// if set, ChannelPool will manage the life cycle of channelRefreshExecutorService
@Nullable private final ScheduledExecutorService channelRefreshExecutorService;
private final ChannelFactory channelFactory;

private volatile ScheduledFuture<?> nextScheduledRefresh = null;

/**
* Factory method to create a non-refreshing channel pool
Expand All @@ -92,8 +85,9 @@ class ChannelPool extends ManagedChannel {
* @param channelFactory method to create the channels
* @return ChannelPool of non-refreshing channels
*/
@VisibleForTesting
static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IOException {
return new ChannelPool(channelFactory, poolSize, null);
return new ChannelPool(ChannelPoolSettings.staticallySized(poolSize), channelFactory, null);
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand All @@ -103,58 +97,66 @@ static ChannelPool create(int poolSize, ChannelFactory channelFactory) throws IO
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @param channelRefreshExecutorService periodically refreshes the channels; its life cycle will
* be managed by ChannelPool
* @param executor used to schedule maintenance tasks like refresh channels and resizing the pool.
* @return ChannelPool of refreshing channels
*/
@VisibleForTesting
static ChannelPool createRefreshing(
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
int poolSize,
ChannelFactory channelFactory,
ScheduledExecutorService channelRefreshExecutorService)
int poolSize, ChannelFactory channelFactory, ScheduledExecutorService executor)
throws IOException {
return new ChannelPool(channelFactory, poolSize, channelRefreshExecutorService);
return new ChannelPool(
ChannelPoolSettings.staticallySized(poolSize)
.toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
channelFactory,
executor);
}

/**
* Factory method to create a refreshing channel pool
*
* @param poolSize number of channels in the pool
* @param channelFactory method to create the channels
* @return ChannelPool of refreshing channels
*/
static ChannelPool createRefreshing(int poolSize, final ChannelFactory channelFactory)
static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFactory)
throws IOException {
return createRefreshing(
poolSize, channelFactory, Executors.newScheduledThreadPool(CHANNEL_REFRESH_EXECUTOR_SIZE));
return new ChannelPool(settings, channelFactory, Executors.newSingleThreadScheduledExecutor());
}

/**
* Initializes the channel pool. Assumes that all channels have the same authority.
*
* @param settings options for controling the ChannelPool sizing behavior
* @param channelFactory method to create the channels
* @param poolSize number of channels in the pool
* @param channelRefreshExecutorService periodically refreshes the channels
* @param executor periodically refreshes the channels
*/
private ChannelPool(
@VisibleForTesting
ChannelPool(
ChannelPoolSettings settings,
ChannelFactory channelFactory,
int poolSize,
@Nullable ScheduledExecutorService channelRefreshExecutorService)
ScheduledExecutorService executor)
throws IOException {
this.settings = settings;
this.channelFactory = channelFactory;

ImmutableList.Builder<Entry> initialListBuilder = ImmutableList.builder();

for (int i = 0; i < poolSize; i++) {
for (int i = 0; i < settings.getInitialChannelCount(); i++) {
initialListBuilder.add(new Entry(channelFactory.createSingleChannel()));
}

entries.set(initialListBuilder.build());
authority = entries.get().get(0).channel.authority();
this.channelRefreshExecutorService = channelRefreshExecutorService;

if (channelRefreshExecutorService != null) {
nextScheduledRefresh = scheduleNextRefresh();
this.executor = executor;

if (!settings.isStaticSize()) {
executor.scheduleAtFixedRate(
this::resizeSafely,
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
TimeUnit.SECONDS);
}
if (settings.isPreemptiveRefreshEnabled()) {
executor.scheduleAtFixedRate(
this::refreshSafely,
REFRESH_PERIOD.getSeconds(),
REFRESH_PERIOD.getSeconds(),
TimeUnit.SECONDS);
}
}

Expand Down Expand Up @@ -187,12 +189,9 @@ public ManagedChannel shutdown() {
for (Entry entry : localEntries) {
entry.channel.shutdown();
}
if (nextScheduledRefresh != null) {
nextScheduledRefresh.cancel(true);
}
if (channelRefreshExecutorService != null) {
if (executor != null) {
// shutdownNow will cancel scheduled tasks
channelRefreshExecutorService.shutdownNow();
executor.shutdownNow();
}
return this;
}
Expand All @@ -206,7 +205,7 @@ public boolean isShutdown() {
return false;
}
}
return channelRefreshExecutorService == null || channelRefreshExecutorService.isShutdown();
return executor == null || executor.isShutdown();
}

/** {@inheritDoc} */
Expand All @@ -218,7 +217,8 @@ public boolean isTerminated() {
return false;
}
}
return channelRefreshExecutorService == null || channelRefreshExecutorService.isTerminated();

return executor == null || executor.isTerminated();
}

/** {@inheritDoc} */
Expand All @@ -228,11 +228,8 @@ public ManagedChannel shutdownNow() {
for (Entry entry : localEntries) {
entry.channel.shutdownNow();
}
if (nextScheduledRefresh != null) {
nextScheduledRefresh.cancel(true);
}
if (channelRefreshExecutorService != null) {
channelRefreshExecutorService.shutdownNow();
if (executor != null) {
executor.shutdownNow();
}
return this;
}
Expand All @@ -249,25 +246,131 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
entry.channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
if (channelRefreshExecutorService != null) {
if (executor != null) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
channelRefreshExecutorService.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
executor.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}

/** Scheduling loop. */
private ScheduledFuture<?> scheduleNextRefresh() {
long delayPeriod = REFRESH_PERIOD.toMillis();
long jitter = (long) ((Math.random() - 0.5) * JITTER_PERCENTAGE * delayPeriod);
long delay = jitter + delayPeriod;
return channelRefreshExecutorService.schedule(
() -> {
scheduleNextRefresh();
refresh();
},
delay,
TimeUnit.MILLISECONDS);
private void resizeSafely() {
try {
synchronized (entryWriteLock) {
resize();
}
} catch (Exception e) {
LOG.log(Level.WARNING, "Failed to resize channel pool", e);
}
}

/**
* Resize the number of channels based on the number of outstanding RPCs.
*
* <p>This method is expected to be called on a fixed interval. On every invocation it will:
*
* <ul>
* <li>Get the maximum number of outstanding RPCs since last invocation
* <li>Determine a valid range of number of channels to handle that many outstanding RPCs
* <li>If the current number of channel falls outside of that range, add or remove at most
* {@link ChannelPoolSettings#MAX_RESIZE_DELTA} to get closer to middle of that range.
* </ul>
*
* <p>Not threadsafe, must be called under the entryWriteLock monitor
*/
@VisibleForTesting
void resize() {
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
List<Entry> localEntries = entries.get();
// Estimate the peak of RPCs in the last interval by summing the peak of RPCs per channel
int actualOutstandingRpcs =
localEntries.stream().mapToInt(Entry::getAndResetMaxOutstanding).sum();

// Number of channels if each channel operated at max capacity
int minChannels =
(int) Math.ceil(actualOutstandingRpcs / (double) settings.getMaxRpcsPerChannel());
// Limit the threshold to absolute range
if (minChannels < settings.getMinChannelCount()) {
minChannels = settings.getMinChannelCount();
}

// Number of channels if each channel operated at minimum capacity
// Note: getMinRpcsPerChannel() can return 0, but division by 0 shouldn't cause a problem.
int maxChannels =
(int) Math.ceil(actualOutstandingRpcs / (double) settings.getMinRpcsPerChannel());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see by default getMinRpcsPerChannel() returns 0, so we are diving by 0. That doesn't fail, but the result is

  • if actualOutstandingRpcs == 0, then maxChannel is 0.
  • if actualOutstandingRpcs != 0, then it's 2147483647.

This is still OK, since you're bounding maxChannels below. However, the division-by-0 computation is not trivial to verify and the computation result of 2147483647 seems atypical. The code gives the impression that the author probably failed to anticipate possible division by 0. At least I want to add a comment like "getMinRpcsPerChannel() can return 0, but division by 0 shouldn't cause a problem."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

// Limit the threshold to absolute range
if (maxChannels > settings.getMaxChannelCount()) {
maxChannels = settings.getMaxChannelCount();
}
if (maxChannels < minChannels) {
maxChannels = minChannels;
}

// If the pool were to be resized, try to aim for the middle of the bound, but limit rate of
// change.
int tentativeTarget = (maxChannels + minChannels) / 2;
int currentSize = localEntries.size();
int delta = tentativeTarget - currentSize;
int dampenedTarget = tentativeTarget;
if (Math.abs(delta) > ChannelPoolSettings.MAX_RESIZE_DELTA) {
dampenedTarget =
currentSize + (int) Math.copySign(ChannelPoolSettings.MAX_RESIZE_DELTA, delta);
}

// Only resize the pool when thresholds are crossed
if (localEntries.size() < minChannels) {
LOG.fine(
String.format(
"Detected throughput peak of %d, expanding channel pool size: %d -> %d.",
actualOutstandingRpcs, currentSize, dampenedTarget));

expand(dampenedTarget);
} else if (localEntries.size() > maxChannels) {
LOG.fine(
String.format(
"Detected throughput drop to %d, shrinking channel pool size: %d -> %d.",
actualOutstandingRpcs, currentSize, dampenedTarget));

shrink(dampenedTarget);
}
}

/** Not threadsafe, must be called under the entryWriteLock monitor */
private void shrink(int desiredSize) {
ImmutableList<Entry> localEntries = entries.get();
Preconditions.checkState(
localEntries.size() >= desiredSize, "current size is already smaller than the desired");

// Set the new list
entries.set(localEntries.subList(0, desiredSize));
// clean up removed entries
List<Entry> removed = localEntries.subList(desiredSize, localEntries.size());
removed.forEach(Entry::requestShutdown);
}

/** Not threadsafe, must be called under the entryWriteLock monitor */
private void expand(int desiredSize) {
List<Entry> localEntries = entries.get();
Preconditions.checkState(
localEntries.size() <= desiredSize, "current size is already bigger than the desired");

ImmutableList.Builder<Entry> newEntries = ImmutableList.<Entry>builder().addAll(localEntries);

for (int i = 0; i < desiredSize - localEntries.size(); i++) {
try {
newEntries.add(new Entry(channelFactory.createSingleChannel()));
} catch (IOException e) {
LOG.log(Level.WARNING, "Failed to add channel", e);
}
}

entries.set(newEntries.build());
}

private void refreshSafely() {
try {
refresh();
} catch (Exception e) {
LOG.log(Level.WARNING, "Failed to pre-emptively refresh channnels", e);
}
}

/**
Expand Down Expand Up @@ -341,13 +444,15 @@ private Entry getEntry(int affinity) {
List<Entry> localEntries = entries.get();

int index = Math.abs(affinity % localEntries.size());

return localEntries.get(index);
}

/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
private static class Entry {
private final ManagedChannel channel;
private final AtomicInteger outstandingRpcs = new AtomicInteger(0);
private final AtomicInteger maxOutstanding = new AtomicInteger();

// Flag that the channel should be closed once all of the outstanding RPC complete.
private final AtomicBoolean shutdownRequested = new AtomicBoolean();
Expand All @@ -358,6 +463,10 @@ private Entry(ManagedChannel channel) {
this.channel = channel;
}

int getAndResetMaxOutstanding() {
return maxOutstanding.getAndSet(outstandingRpcs.get());
}

/**
* Try to increment the outstanding RPC count. The method will return false if the channel is
* closing and the caller should pick a different channel. If the method returned true, the
Expand All @@ -366,7 +475,13 @@ private Entry(ManagedChannel channel) {
*/
private boolean retain() {
// register desire to start RPC
outstandingRpcs.incrementAndGet();
int currentOutstanding = outstandingRpcs.incrementAndGet();

// Rough book keeping
int prevMax = maxOutstanding.get();
if (currentOutstanding > prevMax) {
maxOutstanding.incrementAndGet();
}

// abort if the channel is closing
if (shutdownRequested.get()) {
Expand Down
Loading