From c771c43b870fb8618db7bdab6725ab40cac4976d Mon Sep 17 00:00:00 2001
From: Chi Wang
Date: Tue, 22 Mar 2022 00:37:29 +0800
Subject: [PATCH] Remote: Fix crashes by InterruptedException when dynamic execution is enabled. (#15091)

Fixes #14433.

The root cause is that, inside `RemoteExecutionCache`, the result of
`FindMissingDigests` is shared with other threads without proper error
handling. For example, if two or more threads are uploading the same input
and one of them is interrupted while waiting for the result of the
`FindMissingDigests` call, the call is cancelled and the other threads that
are still waiting for the upload receive an upload error caused by the
cancellation, which is wrong.

This PR fixes the issue by effectively applying a reference count to the
result of the `FindMissingDigests` call: as long as other threads still
depend on the result, the call is not cancelled when one thread is
interrupted, and the upload can continue.

Closes #15001.

PiperOrigin-RevId: 436180205
---
 .../lib/remote/RemoteExecutionCache.java      | 204 +++++++++++++-----
 .../build/lib/remote/util/AsyncTaskCache.java |  33 ++-
 .../remote/RemoteExecutionServiceTest.java    |  83 ++++++-
 .../lib/remote/util/InMemoryCacheClient.java  |  28 ++-
 4 files changed, 270 insertions(+), 78 deletions(-)

diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
index 229163eb37d05d..5474f884233832 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionCache.java
@@ -13,30 +13,39 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static com.google.devtools.build.lib.remote.util.RxFutures.toCompletable;
+import static com.google.devtools.build.lib.remote.util.RxFutures.toSingle;
+import static com.google.devtools.build.lib.remote.util.RxUtils.mergeBulkTransfer;
+import static com.google.devtools.build.lib.remote.util.RxUtils.toTransferResult;
 import static java.lang.String.format;
 
 import build.bazel.remote.execution.v2.Digest;
 import build.bazel.remote.execution.v2.Directory;
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
 import com.google.devtools.build.lib.remote.merkletree.MerkleTree;
 import com.google.devtools.build.lib.remote.merkletree.MerkleTree.PathOrBytes;
 import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
-import com.google.devtools.build.lib.remote.util.RxFutures;
+import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult;
 import com.google.protobuf.Message;
 import io.reactivex.rxjava3.core.Completable;
+import io.reactivex.rxjava3.core.Flowable;
+import io.reactivex.rxjava3.core.Single;
 import io.reactivex.rxjava3.subjects.AsyncSubject;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashSet;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.concurrent.GuardedBy;
 
 /** A {@link RemoteCache} with additional functionality needed for remote execution. */
 public class RemoteExecutionCache extends RemoteCache {
@@ -72,62 +81,58 @@ public void ensureInputsPresent(
             .addAll(additionalInputs.keySet())
             .build();
 
-    // Collect digests that are not being or already uploaded
-    ConcurrentHashMap<Digest, AsyncSubject<Boolean>> missingDigestSubjects =
-        new ConcurrentHashMap<>();
-
-    List<ListenableFuture<Void>> uploadFutures = new ArrayList<>();
-    for (Digest digest : allDigests) {
-      Completable upload =
-          casUploadCache.execute(
-              digest,
-              Completable.defer(
-                  () -> {
-                    // The digest hasn't been processed, add it to the collection which will be used
-                    // later for findMissingDigests call
-                    AsyncSubject<Boolean> missingDigestSubject = AsyncSubject.create();
-                    missingDigestSubjects.put(digest, missingDigestSubject);
-
-                    return missingDigestSubject.flatMapCompletable(
-                        missing -> {
-                          if (!missing) {
-                            return Completable.complete();
-                          }
-                          return RxFutures.toCompletable(
-                              () -> uploadBlob(context, digest, merkleTree, additionalInputs),
-                              MoreExecutors.directExecutor());
-                        });
-                  }),
-              force);
-      uploadFutures.add(RxFutures.toListenableFuture(upload));
+    if (allDigests.isEmpty()) {
+      return;
     }
 
-    ImmutableSet<Digest> missingDigests;
-    try {
-      missingDigests = getFromFuture(findMissingDigests(context, missingDigestSubjects.keySet()));
-    } catch (IOException | InterruptedException e) {
-      for (Map.Entry<Digest, AsyncSubject<Boolean>> entry : missingDigestSubjects.entrySet()) {
-        entry.getValue().onError(e);
-      }
+    MissingDigestFinder missingDigestFinder = new MissingDigestFinder(context, allDigests.size());
+    Flowable<TransferResult> uploads =
+        Flowable.fromIterable(allDigests)
+            .flatMapSingle(
+                digest ->
+                    uploadBlobIfMissing(
+                        context, merkleTree, additionalInputs, force, missingDigestFinder, digest));
 
-      if (e instanceof InterruptedException) {
-        Thread.currentThread().interrupt();
+    try {
+      mergeBulkTransfer(uploads).blockingAwait();
+    } catch (RuntimeException e) {
+      Throwable cause = e.getCause();
+      if (cause != null) {
+        Throwables.throwIfInstanceOf(cause, InterruptedException.class);
+        Throwables.throwIfInstanceOf(cause, IOException.class);
       }
       throw e;
     }
+  }
 
-    for (Map.Entry<Digest, AsyncSubject<Boolean>> entry : missingDigestSubjects.entrySet()) {
-      AsyncSubject<Boolean> missingSubject = entry.getValue();
-      if (missingDigests.contains(entry.getKey())) {
-        missingSubject.onNext(true);
-      } else {
-        // The digest is already existed in the remote cache, skip the upload.
-        missingSubject.onNext(false);
-      }
-      missingSubject.onComplete();
-    }
-
-    waitForBulkTransfer(uploadFutures, /* cancelRemainingOnInterrupt=*/ false);
+  private Single<TransferResult> uploadBlobIfMissing(
+      RemoteActionExecutionContext context,
+      MerkleTree merkleTree,
+      Map<Digest, Message> additionalInputs,
+      boolean force,
+      MissingDigestFinder missingDigestFinder,
+      Digest digest) {
+    Completable upload =
+        casUploadCache.execute(
+            digest,
+            Completable.defer(
+                () ->
+                    // Only reach here if the digest is missing and is not being uploaded.
+                    missingDigestFinder
+                        .registerAndCount(digest)
+                        .flatMapCompletable(
+                            missingDigests -> {
+                              if (missingDigests.contains(digest)) {
+                                return toCompletable(
+                                    () -> uploadBlob(context, digest, merkleTree, additionalInputs),
+                                    directExecutor());
+                              } else {
+                                return Completable.complete();
+                              }
+                            })),
+            /* onIgnored= */ missingDigestFinder::count,
+            force);
+    return toTransferResult(upload);
   }
 
   private ListenableFuture<Void> uploadBlob(
@@ -159,4 +164,93 @@ private ListenableFuture<Void> uploadBlob(
             "findMissingDigests returned a missing digest that has not been requested: %s",
             digest)));
   }
+
+  /**
+   * A missing digest finder that initiates the request when the internal counter reaches an
+   * expected count.
+   */
+  class MissingDigestFinder {
+    private final int expectedCount;
+
+    private final AsyncSubject<ImmutableSet<Digest>> digestsSubject;
+    private final Single<ImmutableSet<Digest>> resultSingle;
+
+    @GuardedBy("this")
+    private final Set<Digest> digests;
+
+    @GuardedBy("this")
+    private int currentCount = 0;
+
+    MissingDigestFinder(RemoteActionExecutionContext context, int expectedCount) {
+      checkArgument(expectedCount > 0, "expectedCount should be greater than 0");
+      this.expectedCount = expectedCount;
+      this.digestsSubject = AsyncSubject.create();
+      this.digests = new HashSet<>();
+
+      AtomicBoolean findMissingDigestsCalled = new AtomicBoolean(false);
+      this.resultSingle =
+          Single.fromObservable(
+              digestsSubject
+                  .flatMapSingle(
+                      digests -> {
+                        boolean wasCalled = findMissingDigestsCalled.getAndSet(true);
+                        // Make sure we don't have a re-subscription caused by refCount() below.
+                        checkState(!wasCalled, "FindMissingDigests is called more than once");
+                        return toSingle(
+                            () -> findMissingDigests(context, digests), directExecutor());
+                      })
+                  // Use replay here because we could have a race condition where the downstream
+                  // hasn't been added to the subscription list (to receive the upstream result)
+                  // while upstream is completed.
+                  .replay(1)
+                  .refCount());
+    }
+
+    /**
+     * Register the {@code digest} and increase the counter.
+     *
+     * <p>The returned Single cannot be subscribed to more than once.
+     *
+     * @return Single that emits the result of the {@code FindMissingDigests} request.
+     */
+    Single<ImmutableSet<Digest>> registerAndCount(Digest digest) {
+      AtomicBoolean subscribed = new AtomicBoolean(false);
+      // count() will potentially trigger the findMissingDigests call. Adding and counting before
+      // returning the Single could introduce a race where the result of findMissingDigests is
+      // available but the consumer doesn't get it because it hasn't subscribed to the returned
+      // Single. In this case, it subscribes after upstream is completed, resulting in a re-run of
+      // findMissingDigests (due to refCount()).
+      //
+      // Calling count() inside doOnSubscribe ensures the consumer has already subscribed to the
+      // returned Single, avoiding a re-execution of findMissingDigests.
+      return resultSingle.doOnSubscribe(
+          d -> {
+            boolean wasSubscribed = subscribed.getAndSet(true);
+            checkState(!wasSubscribed, "Single is subscribed more than once");
+            synchronized (this) {
+              digests.add(digest);
+            }
+            count();
+          });
+    }
+
+    /** Increase the counter. */
+    void count() {
+      ImmutableSet<Digest> digestsResult = null;
+
+      synchronized (this) {
+        if (currentCount < expectedCount) {
+          currentCount++;
+          if (currentCount == expectedCount) {
+            digestsResult = ImmutableSet.copyOf(digests);
+          }
+        }
+      }
+
+      if (digestsResult != null) {
+        digestsSubject.onNext(digestsResult);
+        digestsSubject.onComplete();
+      }
+    }
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
index 8fb6f4ce20d49f..31369ef4ee1eab 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java
@@ -24,6 +24,7 @@
 import io.reactivex.rxjava3.core.Single;
 import io.reactivex.rxjava3.core.SingleObserver;
 import io.reactivex.rxjava3.disposables.Disposable;
+import io.reactivex.rxjava3.functions.Action;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -256,14 +257,25 @@ public boolean isDisposed() {
   /**
    * Executes a task.
    *
+   * @see #execute(Object, Single, Action, boolean).
+   */
+  public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
+    return execute(key, task, () -> {}, force);
+  }
+
+  /**
+   * Executes a task. If the task has already finished, this execution of the task is ignored unless
+   * `force` is true. If the task is in progress, this execution of the task is always ignored.
+   *
    * <p>If the cache is already shutdown, a {@link CancellationException} will be emitted.
    *
    * @param key identifies the task.
+   * @param onIgnored callback called when the provided task is ignored.
    * @param force re-execute a finished task if set to {@code true}.
    * @return a {@link Single} which turns to completed once the task is finished or propagates the
    *     error if any.
    */
-  public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
+  public Single<ValueT> execute(KeyT key, Single<ValueT> task, Action onIgnored, boolean force) {
     return Single.create(
         emitter -> {
           synchronized (lock) {
@@ -273,14 +285,20 @@ public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
            }
 
            if (!force && finished.containsKey(key)) {
+             onIgnored.run();
              emitter.onSuccess(finished.get(key));
              return;
            }
 
            finished.remove(key);
 
-            Execution execution =
-                inProgress.computeIfAbsent(key, ignoredKey -> new Execution(key, task));
+            Execution execution = inProgress.get(key);
+            if (execution != null) {
+              onIgnored.run();
+            } else {
+              execution = new Execution(key, task);
+              inProgress.put(key, execution);
+            }
 
            // We must subscribe the execution within the scope of lock to avoid race condition
            // that:
@@ -425,10 +443,15 @@ public Completable executeIfNot(KeyT key, Completable task) {
           cache.executeIfNot(key, task.toSingleDefault(Optional.empty())));
     }
 
-    /** Same as {@link AsyncTaskCache#executeIfNot} but operates on {@link Completable}. */
+    /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */
     public Completable execute(KeyT key, Completable task, boolean force) {
+      return execute(key, task, () -> {}, force);
+    }
+
+    /** Same as {@link AsyncTaskCache#execute} but operates on {@link Completable}. */
+    public Completable execute(KeyT key, Completable task, Action onIgnored, boolean force) {
       return Completable.fromSingle(
-          cache.execute(key, task.toSingleDefault(Optional.empty()), force));
+          cache.execute(key, task.toSingleDefault(Optional.empty()), onIgnored, force));
     }
 
     /** Returns a set of keys for tasks which is finished. */
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
index b885e3b53d70d4..18b5a688106763 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteExecutionServiceTest.java
@@ -51,6 +51,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.eventbus.EventBus;
 import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.SettableFuture;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputHelper;
 import com.google.devtools.build.lib.actions.ActionUploadFinishedEvent;
@@ -109,6 +110,7 @@
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Random;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Semaphore;
@@ -1433,19 +1435,18 @@ public void uploadInputsIfNotPresent_deduplicateFindMissingBlobCalls() throws Ex
     ActionInput input = ActionInputHelper.fromPath("inputs/foo");
     Digest inputDigest = fakeFileCache.createScratchInput(input, "input-foo");
     RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
 
     for (int i = 0; i < taskCount; ++i) {
       executorService.execute(
           () -> {
             try {
-              Spawn spawn =
-                  newSpawn(
-                      ImmutableMap.of(),
-                      ImmutableSet.of(),
-                      NestedSetBuilder.create(Order.STABLE_ORDER, input));
-              FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
-              RemoteAction action = service.buildRemoteAction(spawn, context);
-
               service.uploadInputsIfNotPresent(action, /*force=*/ false);
             } catch (Throwable e) {
               if (e instanceof InterruptedException) {
@@ -1466,6 +1467,72 @@ public void uploadInputsIfNotPresent_deduplicateFindMissingBlobCalls() throws Ex
     }
   }
 
+  @Test
+  public void uploadInputsIfNotPresent_sameInputs_interruptOne_keepOthers() throws Exception {
+    int taskCount = 100;
+    ExecutorService executorService = Executors.newFixedThreadPool(taskCount);
+    AtomicReference<Throwable> error = new AtomicReference<>(null);
+    Semaphore semaphore = new Semaphore(0);
+    ActionInput input = ActionInputHelper.fromPath("inputs/foo");
+    fakeFileCache.createScratchInput(input, "input-foo");
+    RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
+    Random random = new Random();
+
+    for (int i = 0; i < taskCount; ++i) {
+      boolean shouldInterrupt = random.nextBoolean();
+      executorService.execute(
+          () -> {
+            try {
+              if (shouldInterrupt) {
+                Thread.currentThread().interrupt();
+              }
+              service.uploadInputsIfNotPresent(action, /*force=*/ false);
+            } catch (Throwable e) {
+              if (!(shouldInterrupt && e instanceof InterruptedException)) {
+                error.set(e);
+              }
+            } finally {
+              semaphore.release();
+            }
+          });
+    }
+    semaphore.acquire(taskCount);
+
+    assertThat(error.get()).isNull();
+  }
+
+  @Test
+  public void uploadInputsIfNotPresent_interrupted_requestCancelled() throws Exception {
+    SettableFuture<ImmutableSet<Digest>> future = SettableFuture.create();
+    doReturn(future).when(cache).findMissingDigests(any(), any());
+    ActionInput input = ActionInputHelper.fromPath("inputs/foo");
+    fakeFileCache.createScratchInput(input, "input-foo");
+    RemoteExecutionService service = newRemoteExecutionService();
+    Spawn spawn =
+        newSpawn(
+            ImmutableMap.of(),
+            ImmutableSet.of(),
+            NestedSetBuilder.create(Order.STABLE_ORDER, input));
+    FakeSpawnExecutionContext context = newSpawnExecutionContext(spawn);
+    RemoteAction action = service.buildRemoteAction(spawn, context);
+
+    try {
+      Thread.currentThread().interrupt();
+      service.uploadInputsIfNotPresent(action, /*force=*/ false);
+    } catch (InterruptedException ignored) {
+      // Intentionally left empty
+    }
+
+    assertThat(future.isCancelled()).isTrue();
+  }
+
   @Test
   public void buildMerkleTree_withMemoization_works() throws Exception {
     // Test that Merkle tree building can be memoized.
diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
index c26629f2b3cf14..8925640c11ccbc 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/util/InMemoryCacheClient.java
@@ -19,6 +19,8 @@
 import com.google.common.io.ByteStreams;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
 import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
 import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
 import com.google.devtools.build.lib.remote.common.RemoteCacheClient;
@@ -31,12 +33,15 @@
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
 /** A {@link RemoteCacheClient} that stores its contents in memory. */
 public final class InMemoryCacheClient implements RemoteCacheClient {
 
+  private final ListeningExecutorService executorService =
+      MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(100));
   private final ConcurrentMap<Digest, Exception> downloadFailures = new ConcurrentHashMap<>();
   private final ConcurrentMap<ActionKey, ActionResult> ac = new ConcurrentHashMap<>();
   private final ConcurrentMap<Digest, byte[]> cas;
@@ -142,16 +147,19 @@ public ListenableFuture<Void> uploadBlob(
   @Override
   public ListenableFuture<ImmutableSet<Digest>> findMissingDigests(
       RemoteActionExecutionContext context, Iterable<Digest> digests) {
-    ImmutableSet.Builder<Digest> missingBuilder = ImmutableSet.builder();
-    for (Digest digest : digests) {
-      numFindMissingDigests
-          .computeIfAbsent(digest, (key) -> new AtomicInteger(0))
-          .incrementAndGet();
-      if (!cas.containsKey(digest)) {
-        missingBuilder.add(digest);
-      }
-    }
-    return Futures.immediateFuture(missingBuilder.build());
+    return executorService.submit(
+        () -> {
+          ImmutableSet.Builder<Digest> missingBuilder = ImmutableSet.builder();
+          for (Digest digest : digests) {
+            numFindMissingDigests
+                .computeIfAbsent(digest, (key) -> new AtomicInteger(0))
+                .incrementAndGet();
+            if (!cas.containsKey(digest)) {
+              missingBuilder.add(digest);
+            }
+          }
+          return missingBuilder.build();
+        });
   }
 
   @Override
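
Note: the snippet below is an illustration of the reference-counting idea described in the commit message; it is not part of the patch, and the names (RefCountedLookup, lookup, await) are hypothetical. It sketches the same behaviour the new MissingDigestFinder (together with the onIgnored hook added to AsyncTaskCache) gives the FindMissingDigests request, but in plain java.util.concurrent terms: one expensive call is started once and shared, an interrupted waiter only drops its own reference, and the shared call is cancelled only when no waiter remains. The real MissingDigestFinder goes further in that it knows the expected number of participants up front and does not issue FindMissingDigests until all of them have registered; this sketch starts the call eagerly for simplicity.

import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

/** Illustration only: one shared lookup whose cancellation is reference-counted by its waiters. */
final class RefCountedLookup<T> {
  private final CompletableFuture<T> result = new CompletableFuture<>();
  private final AtomicInteger waiters = new AtomicInteger();
  private final Future<?> running;

  RefCountedLookup(Callable<T> lookup, ExecutorService executor) {
    // Start the expensive call exactly once; every waiter blocks on `result` below.
    this.running =
        executor.submit(
            () -> {
              try {
                result.complete(lookup.call());
              } catch (Exception e) {
                result.completeExceptionally(e);
              }
            });
  }

  /** Blocks for the shared result. Interrupting this waiter does not affect the other waiters. */
  T await() throws InterruptedException, ExecutionException {
    waiters.incrementAndGet();
    try {
      return result.get();
    } catch (InterruptedException e) {
      // Only the last interested waiter may cancel the shared lookup; earlier interrupts
      // merely drop their own reference and leave the call running for everyone else.
      if (waiters.decrementAndGet() == 0 && !result.isDone()) {
        running.cancel(/* mayInterruptIfRunning= */ true);
        result.completeExceptionally(new CancellationException("no waiters left"));
      }
      throw e;
    }
  }
}

With this shape, a thread that is interrupted rethrows InterruptedException for itself only, while threads already blocked in await() keep waiting on the same result, which is the behaviour the tests above (interruptOne_keepOthers and interrupted_requestCancelled) exercise against the real implementation.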