Skip to content

Commit

Permalink
[7.4.0] Do not reuse gRPC connections that fail with native Netty errors (#23343)
Browse files Browse the repository at this point in the history

Such connections are usually not in a recoverable state and should not
be used for retries, which would otherwise likely fail in the same way.

Fixes #20868

Closes #23150.

PiperOrigin-RevId: 662091153
Change-Id: Iaf160b11a13af013b9969c7fdaa966bca8ab6be2

Commit
06691b3

Co-authored-by: Fabian Meumertzheim <fabian@meumertzhe.im>
  • Loading branch information
bazel-io and fmeum authored Aug 21, 2024
1 parent a9a66ae commit 60cbce1
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//third_party:guava",
"//third_party:jsr305",
"//third_party:netty",
"//third_party:rxjava3",
"//third_party/grpc-java:grpc-jar",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;

import com.google.common.annotations.VisibleForTesting;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.netty.channel.unix.Errors;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.functions.Action;
Expand All @@ -25,6 +31,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

Expand All @@ -41,6 +48,7 @@
public class SharedConnectionFactory implements ConnectionPool {
private final TokenBucket<Integer> tokenBucket;
private final ConnectionFactory factory;
private final Predicate<Throwable> fatalErrorPredicate;

@Nullable
@GuardedBy("this")
Expand All @@ -50,13 +58,20 @@ public class SharedConnectionFactory implements ConnectionPool {
new AtomicReference<>(null);

/**
 * Creates a factory that uses the default fatal-error check, {@code isFatalError}.
 *
 * <p>Delegates to the three-argument constructor.
 *
 * @param factory forwarded unchanged to the three-argument constructor
 * @param maxConcurrency forwarded unchanged to the three-argument constructor
 */
public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) {
this(factory, maxConcurrency, SharedConnectionFactory::isFatalError);
}

/**
 * Creates a factory with a caller-supplied definition of what counts as a fatal error, so that
 * tests can substitute their own exception type.
 *
 * @param factory the factory stored for creating underlying connections
 * @param maxConcurrency the number of tokens placed in the token bucket
 * @param fatalErrorPredicate the check stored for deciding whether an error is fatal
 */
@VisibleForTesting
SharedConnectionFactory(
    ConnectionFactory factory, int maxConcurrency, Predicate<Throwable> fatalErrorPredicate) {
  this.factory = factory;
  this.fatalErrorPredicate = fatalErrorPredicate;

  // Seed the bucket with the distinct tokens 0 .. maxConcurrency - 1.
  List<Integer> tokens = new ArrayList<>(maxConcurrency);
  int nextToken = 0;
  while (nextToken < maxConcurrency) {
    tokens.add(nextToken);
    nextToken++;
  }
  this.tokenBucket = new TokenBucket<>(tokens);
}

@Override
Expand Down Expand Up @@ -118,7 +133,14 @@ public Single<SharedConnection> create() {
.map(
conn ->
new SharedConnection(
conn, /* onClose= */ () -> tokenBucket.addToken(token))));
conn,
/* onClose= */ () -> tokenBucket.addToken(token),
fatalErrorPredicate,
/* onFatalError= */ () -> {
synchronized (this) {
connectionAsyncSubject = null;
}
})));
}

/** Returns current number of available connections. */
Expand All @@ -130,16 +152,39 @@ public int numAvailableConnections() {
public static class SharedConnection implements Connection {
private final Connection connection;
private final Action onClose;

public SharedConnection(Connection connection, Action onClose) {
private final Predicate<Throwable> fatalErrorPredicate;
private final Runnable onFatalError;

/**
 * Wraps {@code connection} together with the callbacks used by this shared connection.
 *
 * @param connection the underlying connection that calls are forwarded to
 * @param onClose callback stored for later use; presumably run when this shared connection is
 *     closed (the {@code close()} implementation is not visible here, so confirm)
 * @param fatalErrorPredicate tested against the cause of the status each call closes with
 * @param onFatalError run when {@code fatalErrorPredicate} matches that cause
 */
public SharedConnection(
Connection connection,
Action onClose,
Predicate<Throwable> fatalErrorPredicate,
Runnable onFatalError) {
this.connection = connection;
this.onClose = onClose;
this.fatalErrorPredicate = fatalErrorPredicate;
this.onFatalError = onFatalError;
}

/**
 * Creates a call on the underlying connection and wraps it so that the way the call closes can
 * be observed.
 *
 * <p>When the call closes, the cause of its final {@link Status} is passed to {@code
 * fatalErrorPredicate}; if the predicate matches, {@code onFatalError} runs before the close
 * event is forwarded to the caller's listener.
 *
 * @param method the method to invoke, forwarded to the underlying connection
 * @param options the call options, forwarded to the underlying connection
 * @return a forwarding call that delegates everything to the underlying connection's call
 */
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> call(
    MethodDescriptor<ReqT, RespT> method, CallOptions options) {
  return new SimpleForwardingClientCall<>(connection.call(method, options)) {
    @Override
    public void start(Listener<RespT> responseListener, Metadata headers) {
      super.start(
          new SimpleForwardingClientCallListener<>(responseListener) {
            @Override
            public void onClose(Status status, Metadata trailers) {
              // The status may carry no cause, so the predicate must tolerate null.
              if (fatalErrorPredicate.test(status.getCause())) {
                onFatalError.run();
              }
              super.onClose(status, trailers);
            }
          },
          headers);
    }
  };
}

@Override
Expand All @@ -156,4 +201,10 @@ public Connection getUnderlyingConnection() {
return connection;
}
}

/**
 * Default fatal-error check: reports whether {@code t} is a low-level Netty error.
 *
 * <p>Such an error indicates that the connection is fundamentally broken and should not be
 * reused for retries.
 *
 * @param t the error to classify; may be null
 * @return true if {@code t} is an {@code Errors.NativeIoException}, false otherwise
 */
private static boolean isFatalError(@Nullable Throwable t) {
  // instanceof is false for null, so a missing cause is never treated as fatal.
  boolean isNativeNettyError = t instanceof Errors.NativeIoException;
  return isNativeNettyError;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ java_test(
"//third_party:mockito",
"//third_party:rxjava3",
"//third_party:truth",
"//third_party/grpc-java:grpc-jar",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,27 @@
package com.google.devtools.build.lib.remote.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -145,6 +156,82 @@ public void create_concurrentCreate_shareConnections() throws InterruptedExcepti
verify(connectionFactory, times(1)).create();
}

/**
 * Test-only exception standing in for a fatal connection error; the fatal-error predicate used
 * in this test matches it by type.
 */
private static final class FatalIOException extends IOException {
FatalIOException() {
super("fatal");
}
}

/**
 * Verifies that a connection on which a call closed with a fatal error is not handed out again:
 * the next {@code create()} must obtain a fresh connection from the underlying factory.
 */
@SuppressWarnings({"unchecked", "CannotMockFinalClass"})
@Test
public void create_belowMaxConcurrency_fatalErrorPreventsReuse() throws IOException {
  // A connection whose calls close with a fatal error as soon as they are started.
  Connection brokenConnection =
      new Connection() {
        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> call(
            MethodDescriptor<ReqT, RespT> method, CallOptions options) {
          var call = mock(ClientCall.class);
          doAnswer(
                  invocationOnMock -> {
                    ((ClientCall.Listener) invocationOnMock.getArgument(0))
                        .onClose(Status.fromThrowable(new FatalIOException()), new Metadata());
                    return null;
                  })
              .when(call)
              .start(any(), any());
          return call;
        }

        @Override
        public void close() {}
      };
  Connection newConnection = mock(Connection.class);
  // The underlying factory hands out the broken connection first, then the healthy one.
  Queue<Connection> connectionsToCreate =
      new ArrayDeque<>(ImmutableList.of(brokenConnection, newConnection));
  when(connectionFactory.create())
      .thenAnswer(invocation -> Single.just(connectionsToCreate.remove()));

  SharedConnectionFactory factory =
      new SharedConnectionFactory(connectionFactory, 2, t -> t instanceof FatalIOException);

  TestObserver<SharedConnection> observer1 = factory.create().test();
  assertThat(factory.numAvailableConnections()).isEqualTo(1);
  observer1
      .assertValue(conn -> conn.getUnderlyingConnection() == brokenConnection)
      .assertComplete();

  // The marshaller is never exercised; it only satisfies the MethodDescriptor builder.
  MethodDescriptor.Marshaller<byte[]> nullMarshaller =
      new MethodDescriptor.Marshaller<>() {
        @Override
        public InputStream stream(byte[] bytes) {
          return null;
        }

        @Override
        public byte[] parse(InputStream inputStream) {
          return null;
        }
      };

  // Submit a call on the first connection and have it fail. Starting the call is what triggers
  // the failure: the mocked call closes the listener it receives, which is the shared
  // connection's forwarding listener, with a FatalIOException.
  try (Connection firstConnection = observer1.values().getFirst()) {
    var call =
        firstConnection.call(
            MethodDescriptor.newBuilder(nullMarshaller, nullMarshaller)
                .setType(MethodDescriptor.MethodType.CLIENT_STREAMING)
                .setFullMethodName("testMethod")
                .build(),
            CallOptions.DEFAULT);
    ClientCall.Listener<byte[]> listener = new ClientCall.Listener<>() {};
    call.start(listener, new Metadata());
  }

  // Validate that the connection is not reused.
  TestObserver<SharedConnection> observer2 = factory.create().test();
  observer2.assertValue(conn -> conn.getUnderlyingConnection() == newConnection).assertComplete();
  assertThat(factory.numAvailableConnections()).isEqualTo(1);
  // Exactly one underlying connection was created per shared connection handed out.
  verify(connectionFactory, times(2)).create();
}

@Test
public void create_afterLastFailed_success() {
AtomicInteger times = new AtomicInteger(0);
Expand Down

0 comments on commit 60cbce1

Please sign in to comment.