diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD index 8d6e346b48266c..72fcc865ca9152 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -18,6 +18,7 @@ java_library( "//src/main/java/com/google/devtools/build/lib/concurrent", "//third_party:guava", "//third_party:jsr305", + "//third_party:netty", "//third_party:rxjava3", "//third_party/grpc-java:grpc-jar", ], diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java index bb2dddb9d3a3f0..6452103b83b24b 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java @@ -13,10 +13,16 @@ // limitations under the License. package com.google.devtools.build.lib.remote.grpc; +import com.google.common.annotations.VisibleForTesting; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.netty.channel.unix.Errors; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Action; @@ -25,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -41,6 +48,7 @@ public class SharedConnectionFactory implements ConnectionPool { private final TokenBucket tokenBucket; private final ConnectionFactory factory; + private final Predicate fatalErrorPredicate; @Nullable @GuardedBy("this") @@ -50,6 +58,12 @@ public class SharedConnectionFactory implements ConnectionPool { new AtomicReference<>(null); public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) { + this(factory, maxConcurrency, SharedConnectionFactory::isFatalError); + } + + @VisibleForTesting + SharedConnectionFactory( + ConnectionFactory factory, int maxConcurrency, Predicate fatalErrorPredicate) { this.factory = factory; List initialTokens = new ArrayList<>(maxConcurrency); @@ -57,6 +71,7 @@ public SharedConnectionFactory(ConnectionFactory factory, int maxConcurrency) { initialTokens.add(i); } this.tokenBucket = new TokenBucket<>(initialTokens); + this.fatalErrorPredicate = fatalErrorPredicate; } @Override @@ -118,7 +133,14 @@ public Single create() { .map( conn -> new SharedConnection( - conn, /* onClose= */ () -> tokenBucket.addToken(token)))); + conn, + /* onClose= */ () -> tokenBucket.addToken(token), + fatalErrorPredicate, + /* onFatalError= */ () -> { + synchronized (this) { + connectionAsyncSubject = null; + } + }))); } /** Returns current number of available connections. */ @@ -130,16 +152,39 @@ public int numAvailableConnections() { public static class SharedConnection implements Connection { private final Connection connection; private final Action onClose; - - public SharedConnection(Connection connection, Action onClose) { + private final Predicate fatalErrorPredicate; + private final Runnable onFatalError; + + public SharedConnection( + Connection connection, + Action onClose, + Predicate fatalErrorPredicate, + Runnable onFatalError) { this.connection = connection; this.onClose = onClose; + this.fatalErrorPredicate = fatalErrorPredicate; + this.onFatalError = onFatalError; } @Override public ClientCall call( MethodDescriptor method, CallOptions options) { - return connection.call(method, options); + return new SimpleForwardingClientCall<>(connection.call(method, options)) { + @Override + public void start(Listener responseListener, Metadata headers) { + super.start( + new SimpleForwardingClientCallListener<>(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + if (fatalErrorPredicate.test(status.getCause())) { + onFatalError.run(); + } + super.onClose(status, trailers); + } + }, + headers); + } + }; } @Override @@ -156,4 +201,10 @@ public Connection getUnderlyingConnection() { return connection; } } + + private static boolean isFatalError(@Nullable Throwable t) { + // A low-level netty error indicates that the connection is fundamentally broken + // and should not be reused for retries. + return t instanceof Errors.NativeIoException; + } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD index f8b88c67127d7b..e6626e8e3eecf1 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -29,5 +29,6 @@ java_test( "//third_party:mockito", "//third_party:rxjava3", "//third_party:truth", + "//third_party/grpc-java:grpc-jar", ], ) diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java index ad3f3c73f1999d..1b070b8b4dea0f 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java @@ -14,16 +14,27 @@ package com.google.devtools.build.lib.remote.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.observers.TestObserver; import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayDeque; +import java.util.Queue; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -145,6 +156,82 @@ public void create_concurrentCreate_shareConnections() throws InterruptedExcepti verify(connectionFactory, times(1)).create(); } + private static final class FatalIOException extends IOException { + FatalIOException() { + super("fatal"); + } + } + + @SuppressWarnings({"unchecked", "CannotMockFinalClass"}) + @Test + public void create_belowMaxConcurrency_fatalErrorPreventsReuse() throws IOException { + Connection brokenConnection = + new Connection() { + @Override + public ClientCall call( + MethodDescriptor method, CallOptions options) { + var call = mock(ClientCall.class); + doAnswer( + invocationOnMock -> { + ((ClientCall.Listener) invocationOnMock.getArgument(0)) + .onClose(Status.fromThrowable(new FatalIOException()), new Metadata()); + return null; + }) + .when(call) + .start(any(), any()); + return call; + } + + @Override + public void close() {} + }; + Connection newConnection = mock(Connection.class); + Queue connectionsToCreate = + new ArrayDeque<>(ImmutableList.of(brokenConnection, newConnection)); + when(connectionFactory.create()) + .thenAnswer(invocation -> Single.just(connectionsToCreate.remove())); + + SharedConnectionFactory factory = + new SharedConnectionFactory(connectionFactory, 2, t -> t instanceof FatalIOException); + + TestObserver observer1 = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + observer1 + .assertValue(conn -> conn.getUnderlyingConnection() == brokenConnection) + .assertComplete(); + + // Submit a call on the first connection and have it fail. + MethodDescriptor.Marshaller nullMarshaller = + new MethodDescriptor.Marshaller<>() { + @Override + public InputStream stream(byte[] bytes) { + return null; + } + + @Override + public byte[] parse(InputStream inputStream) { + return null; + } + }; + try (Connection firstConnection = observer1.values().getFirst()) { + var call = + firstConnection.call( + MethodDescriptor.newBuilder(nullMarshaller, nullMarshaller) + .setType(MethodDescriptor.MethodType.CLIENT_STREAMING) + .setFullMethodName("testMethod") + .build(), + CallOptions.DEFAULT); + ClientCall.Listener listener = new ClientCall.Listener<>() {}; + call.start(listener, new Metadata()); + listener.onClose(Status.fromThrowable(new FatalIOException()), new Metadata()); + } + + // Validate that the connection is not reused. + TestObserver observer2 = factory.create().test(); + observer2.assertValue(conn -> conn.getUnderlyingConnection() == newConnection).assertComplete(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + } + @Test public void create_afterLastFailed_success() { AtomicInteger times = new AtomicInteger(0);