Skip to content

Commit

Permalink
Remote: Fix an issue that a failed action could lead to RuntimeExcept…
Browse files Browse the repository at this point in the history
…ion caused by InterruptedException thrown when acquiring gRPC connections. #13239

When --keep_going is not enabled, Bazel will cancel other executing actions if an action failed. An action which is executing remotely could in the state of waiting for a lock available to acquire the gRPC connection. SharedConnectionFactory uses ReentrantLock#lockInterruptibly to acquire the lock and will throw InterruptedException when the thread is interrupted which happens when the action is cancelled by Bazel. However, this InterruptedException is wrapped inside a RuntimeException results in a build error.

ReentrantLock was choosen initially to implement a hand-over-hand locking algorithem but it's no longer necessary after a few iterations. This change replaces ReentrantLock with `synchronized` keyword so we won't throw InterruptedException when acquiring gRPC connections. Call sites can still throw InterruptedException to cancel an action execution.

PiperOrigin-RevId: 365170212
  • Loading branch information
coeuvre authored and philwo committed Apr 19, 2021
1 parent c05fef4 commit b0d9fec
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

Expand All @@ -44,10 +43,9 @@ public class SharedConnectionFactory implements ConnectionPool {
private final ConnectionFactory factory;

@Nullable
@GuardedBy("connectionLock")
@GuardedBy("this")
private AsyncSubject<Connection> connectionAsyncSubject = null;

private final ReentrantLock connectionLock = new ReentrantLock();
private final AtomicReference<Disposable> connectionCreationDisposable =
new AtomicReference<>(null);

Expand All @@ -70,9 +68,7 @@ public void close() throws IOException {
d.dispose();
}

try {
connectionLock.lockInterruptibly();

synchronized (this) {
if (connectionAsyncSubject != null) {
Connection connection = connectionAsyncSubject.getValue();
if (connection != null) {
Expand All @@ -83,16 +79,11 @@ public void close() throws IOException {
connectionAsyncSubject.onError(new IllegalStateException("closed"));
}
}
} catch (InterruptedException e) {
throw new IOException(e);
} finally {
connectionLock.unlock();
}
}

private AsyncSubject<Connection> createUnderlyingConnectionIfNot() throws InterruptedException {
connectionLock.lockInterruptibly();
try {
private AsyncSubject<Connection> createUnderlyingConnectionIfNot() {
synchronized (this) {
if (connectionAsyncSubject == null || connectionAsyncSubject.hasThrowable()) {
connectionAsyncSubject =
factory
Expand All @@ -103,14 +94,11 @@ private AsyncSubject<Connection> createUnderlyingConnectionIfNot() throws Interr
}

return connectionAsyncSubject;
} finally {
connectionLock.unlock();
}
}

private Single<? extends Connection> acquireConnection() {
return Single.fromCallable(this::createUnderlyingConnectionIfNot)
.flatMap(Single::fromObservable);
return Single.fromObservable(createUnderlyingConnectionIfNot());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ java_test(
deps = [
"//src/main/java/com/google/devtools/build/lib/remote/grpc",
"//src/test/java/com/google/devtools/build/lib:test_runner",
"//src/test/java/com/google/devtools/build/lib/remote/util",
"//third_party:guava",
"//third_party:junit4",
"//third_party:mockito",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
import static org.mockito.Mockito.when;

import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
import java.io.IOException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -42,28 +41,16 @@
@RunWith(JUnit4.class)
public class SharedConnectionFactoryTest {
@Rule public final MockitoRule mockito = MockitoJUnit.rule();

private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null);
@Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule();

@Mock private Connection connection;
@Mock private ConnectionFactory connectionFactory;

@Before
public void setUp() {
RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);

when(connectionFactory.create()).thenAnswer(invocation -> Single.just(connection));
}

@After
public void tearDown() throws Throwable {
// Make sure rxjava didn't receive global errors
Throwable t = rxGlobalThrowable.getAndSet(null);
if (t != null) {
throw t;
}
}

@Test
public void create_smoke() {
SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
Expand Down Expand Up @@ -125,32 +112,37 @@ public void create_belowMaxConcurrency_shareConnections() {

@Test
public void create_concurrentCreate_shareConnections() throws InterruptedException {
SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2);
Semaphore semaphore = new Semaphore(0);
AtomicBoolean finished = new AtomicBoolean(false);
Thread t =
new Thread(
() -> {
factory
.create()
.doOnSuccess(
conn -> {
assertThat(conn.getUnderlyingConnection()).isEqualTo(connection);
semaphore.release();
Thread.sleep(Integer.MAX_VALUE);
finished.set(true);
})
.blockingSubscribe();

finished.set(true);
});
t.start();
semaphore.acquire();
int maxConcurrency = 10;
SharedConnectionFactory factory =
new SharedConnectionFactory(connectionFactory, maxConcurrency);
AtomicReference<Throwable> error = new AtomicReference<>(null);
Runnable runnable =
() -> {
try {
TestObserver<SharedConnection> observer = factory.create().test();

observer
.assertNoErrors()
.assertValue(conn -> conn.getUnderlyingConnection() == connection)
.assertComplete();
} catch (Throwable e) {
error.set(e);
}
};
Thread[] threads = new Thread[maxConcurrency];
for (int i = 0; i < threads.length; ++i) {
threads[i] = new Thread(runnable);
}

TestObserver<SharedConnection> observer = factory.create().test();
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
thread.join();
}

observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete();
assertThat(finished.get()).isFalse();
assertThat(error.get()).isNull();
verify(connectionFactory, times(1)).create();
}

@Test
Expand Down

0 comments on commit b0d9fec

Please sign in to comment.