Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
fix: do not override grpc default executor
Browse files Browse the repository at this point in the history
  • Loading branch information
mutianf committed Apr 28, 2021
1 parent 52e39f8 commit 73bc3ce
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
: builder.directPathServiceConfig;
}

@Deprecated
@Override
public boolean needsExecutor() {
return executor == null;
Expand Down Expand Up @@ -200,9 +201,7 @@ public TransportChannelProvider withCredentials(Credentials credentials) {

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsExecutor()) {
throw new IllegalStateException("getTransportChannel() called when needsExecutor() is true");
} else if (needsHeaders()) {
if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else if (needsEndpoint()) {
throw new IllegalStateException("getTransportChannel() called when needsEndpoint() is true");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,36 @@

import com.google.api.core.ApiFunction;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import com.google.type.Color;
import com.google.type.Money;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -443,6 +455,62 @@ public void testWithDefaultDirectPathServiceConfig() {
assertThat(childPolicy.keySet()).containsExactly("pick_first");
}

@Test
public void testDefaultExecutor() throws Exception {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.build();

// The default name thread name for grpc threads configured in GrpcUtil
assertThat(extractExecutorThreadName(provider)).contains("grpc-default-executor");
}

/**
* Extract the name of the channel executor thread by instantiating a channel and issuing a fake
* call.
*/
private static String extractExecutorThreadName(InstantiatingGrpcChannelProvider channelProvider)
throws IOException, ExecutionException, InterruptedException {
GrpcTransportChannel transportChannel =
(GrpcTransportChannel) channelProvider.getTransportChannel();
try {
Channel channel = transportChannel.getChannel();

ClientCall<com.google.type.Color, Money> call =
channel.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT);
Color request = Color.getDefaultInstance();

final SettableFuture<String> threadNameFuture = SettableFuture.create();

// Issue a call just to get the thread name of the channel executor
ClientCalls.asyncUnaryCall(
call,
request,
new StreamObserver<Money>() {
@Override
public void onNext(Money ignored) {
threadNameFuture.set(Thread.currentThread().getName());
}

@Override
public void onError(Throwable ignored) {
threadNameFuture.set(Thread.currentThread().getName());
}

@Override
public void onCompleted() {
threadNameFuture.set(Thread.currentThread().getName());
}
});
return threadNameFuture.get();
} finally {
transportChannel.shutdown();
transportChannel.awaitTermination(10, TimeUnit.SECONDS);
}
}

@Nullable
private static Map<String, ?> getAsObject(Map<String, ?> json, String key) {
Object mapObject = json.get(key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public boolean shouldAutoClose() {
return true;
}

@Deprecated
@Override
public boolean needsExecutor() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ private InstantiatingHttpJsonChannelProvider(
this.httpTransport = httpTransport;
}

@Deprecated
@Override
public boolean needsExecutor() {
return executor == null;
Expand Down Expand Up @@ -140,9 +141,7 @@ public String getTransportName() {

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsExecutor()) {
throw new IllegalStateException("getTransportChannel() called when needsExecutor() is true");
} else if (needsHeaders()) {
if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else {
return createChannel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,26 @@
import com.google.api.gax.core.BackgroundResource;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/** Implementation of HttpJsonChannel which can issue http-json calls. */
@BetaApi
public class ManagedHttpJsonChannel implements HttpJsonChannel, BackgroundResource {
private static final JsonFactory JSON_FACTORY = GsonFactory.getDefaultInstance();
private static final ExecutorService DEFAULT_EXECUTOR =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("http-default-executor-%d")
.build());

private final Executor executor;
private final String endpoint;
Expand Down Expand Up @@ -134,7 +143,9 @@ public boolean awaitTermination(long duration, TimeUnit unit) throws Interrupted
public void close() {}

public static Builder newBuilder() {
return new Builder().setHeaderEnhancers(new LinkedList<HttpJsonHeaderEnhancer>());
return new Builder()
.setHeaderEnhancers(new LinkedList<HttpJsonHeaderEnhancer>())
.setExecutor(DEFAULT_EXECUTOR);
}

public static class Builder {
Expand All @@ -147,7 +158,11 @@ public static class Builder {
private Builder() {}

public Builder setExecutor(Executor executor) {
this.executor = executor;
if (executor != null) {
this.executor = executor;
} else {
this.executor = DEFAULT_EXECUTOR;
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,29 @@

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;

import com.google.api.core.ApiFuture;
import com.google.api.gax.httpjson.testing.MockHttpService;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

@RunWith(JUnit4.class)
public class InstantiatingHttpJsonChannelProviderTest {
Expand Down Expand Up @@ -94,4 +107,85 @@ public void basicTest() throws IOException {
// Make sure we can create channels OK.
provider.getTransportChannel().shutdownNow();
}

@Test
public void testDefaultExecutor() throws Exception {
// Create a mock service that will always return errors. We just want to inspect the thread that
// those errors are returned on
MockHttpService mockHttpService =
new MockHttpService(Collections.<ApiMethodDescriptor>emptyList(), "/");
mockHttpService.addException(new RuntimeException("Fake error"));
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.setHttpTransport(mockHttpService)
.build();

assertThat(getThreadName(channelProvider)).contains("http-default-executor");
}

@Test
public void testExecutorOverride() throws IOException, ExecutionException, InterruptedException {
MockHttpService mockHttpService =
new MockHttpService(Collections.<ApiMethodDescriptor>emptyList(), "/");
mockHttpService.addException(new RuntimeException("Fake error"));

final String expectedThreadName = "testExecutorOverrideExecutor";

ExecutorService executor =
Executors.newFixedThreadPool(
1,
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(expectedThreadName).build());
try {
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setExecutor(executor)
.setEndpoint("localhost:1234")
.setHeaderProvider(FixedHeaderProvider.create())
.setHttpTransport(mockHttpService)
.build();

assertThat(getThreadName(channelProvider)).isEqualTo(expectedThreadName);
} finally {
executor.shutdown();
executor.awaitTermination(10, TimeUnit.SECONDS);
}
}

private static String getThreadName(InstantiatingHttpJsonChannelProvider provider)
throws IOException, InterruptedException, ExecutionException {
@SuppressWarnings("unchecked")
ApiMethodDescriptor<Object, Object> apiMethodDescriptor =
mock(
ApiMethodDescriptor.class,
new Answer() {
@Override
public Object answer(InvocationOnMock invocation) {
throw new UnsupportedOperationException("fake error");
}
});

HttpJsonTransportChannel transportChannel =
(HttpJsonTransportChannel) provider.getTransportChannel();
final SettableFuture<String> threadNameFuture = SettableFuture.create();
try {
HttpJsonChannel channel = transportChannel.getChannel();
ApiFuture<Object> rpcFuture =
channel.issueFutureUnaryCall(
HttpJsonCallOptions.newBuilder().build(), new Object(), apiMethodDescriptor);
rpcFuture.addListener(
new Runnable() {
@Override
public void run() {
threadNameFuture.set(Thread.currentThread().getName());
}
},
MoreExecutors.directExecutor());
} finally {
transportChannel.shutdown();
transportChannel.awaitTermination(10, TimeUnit.SECONDS);
}
return threadNameFuture.get();
}
}
25 changes: 16 additions & 9 deletions gax/src/main/java/com/google/api/gax/rpc/ClientContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -139,8 +138,13 @@ public static ClientContext create(ClientSettings settings) throws IOException {
public static ClientContext create(StubSettings settings) throws IOException {
ApiClock clock = settings.getClock();

ExecutorProvider executorProvider = settings.getExecutorProvider();
final ScheduledExecutorService executor = executorProvider.getExecutor();
ExecutorProvider workerExecutorProvider = settings.getWorkerExecutorProvider();
final ScheduledExecutorService workerExecutor = workerExecutorProvider.getExecutor();

final ScheduledExecutorService executor =
settings.getExecutorProvider() == null
? null
: settings.getExecutorProvider().getExecutor();

Credentials credentials = settings.getCredentialsProvider().getCredentials();

Expand All @@ -153,8 +157,11 @@ public static ClientContext create(StubSettings settings) throws IOException {
}

TransportChannelProvider transportChannelProvider = settings.getTransportChannelProvider();
if (transportChannelProvider.needsExecutor()) {
transportChannelProvider = transportChannelProvider.withExecutor((Executor) executor);
// After needsExecutor and StubSettings#setExecutor are deprecated, transport channel executor
// can only be set from TransportChannelProvider#withExecutor directly, and all providers will
// have default executors.
if (transportChannelProvider.needsExecutor() && executor != null) {
transportChannelProvider = transportChannelProvider.withExecutor(executor);
}
Map<String, String> headers = getHeadersFromSettings(settings);
if (transportChannelProvider.needsHeaders()) {
Expand Down Expand Up @@ -186,7 +193,7 @@ public static ClientContext create(StubSettings settings) throws IOException {
watchdogProvider = watchdogProvider.withClock(clock);
}
if (watchdogProvider.needsExecutor()) {
watchdogProvider = watchdogProvider.withExecutor(executor);
watchdogProvider = watchdogProvider.withExecutor(workerExecutor);
}
watchdog = watchdogProvider.getWatchdog();
}
Expand All @@ -196,16 +203,16 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.shouldAutoClose()) {
backgroundResources.add(transportChannel);
}
if (executorProvider.shouldAutoClose()) {
backgroundResources.add(new ExecutorAsBackgroundResource(executor));
if (workerExecutorProvider.shouldAutoClose()) {
backgroundResources.add(new ExecutorAsBackgroundResource(workerExecutor));
}
if (watchdogProvider != null && watchdogProvider.shouldAutoClose()) {
backgroundResources.add(watchdog);
}

return newBuilder()
.setBackgroundResources(backgroundResources.build())
.setExecutor(executor)
.setExecutor(workerExecutor)
.setCredentials(credentials)
.setTransportChannel(transportChannel)
.setHeaders(ImmutableMap.copyOf(settings.getHeaderProvider().getHeaders()))
Expand Down
Loading

0 comments on commit 73bc3ce

Please sign in to comment.