Skip to content

Commit

Permalink
Remote: Another attempt to fix the CancellationException error in Asy…
Browse files Browse the repository at this point in the history
…ncTaskCache caused by a race condition.

PiperOrigin-RevId: 382662649
  • Loading branch information
coeuvre authored and copybara-github committed Jul 2, 2021
1 parent faf188f commit 07a84ce
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,19 @@

import static com.google.common.base.Preconditions.checkState;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.reactivex.rxjava3.annotations.NonNull;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleObserver;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.subjects.AsyncSubject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

Expand All @@ -55,7 +54,7 @@ public final class AsyncTaskCache<KeyT, ValueT> {
private final Map<KeyT, ValueT> finished;

@GuardedBy("lock")
private final Map<KeyT, Execution<ValueT>> inProgress;
private final Map<KeyT, Execution> inProgress;

public static <KeyT, ValueT> AsyncTaskCache<KeyT, ValueT> create() {
return new AsyncTaskCache<>();
Expand Down Expand Up @@ -91,79 +90,136 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
return execute(key, task, false);
}

private static class Execution<ValueT> {
private final AtomicBoolean isTaskDisposed = new AtomicBoolean(false);
private final Single<ValueT> task;
private final AsyncSubject<ValueT> asyncSubject = AsyncSubject.create();
private final AtomicInteger referenceCount = new AtomicInteger(0);
private final AtomicReference<Disposable> taskDisposable = new AtomicReference<>(null);
/** Returns count of subscribers for a task. */
public int getSubscriberCount(KeyT key) {
synchronized (lock) {
Execution task = inProgress.get(key);
if (task != null) {
return task.getSubscriberCount();
}
}

return 0;
}

class Execution extends Single<ValueT> implements SingleObserver<ValueT> {
private final KeyT key;
private final Single<ValueT> upstream;

@GuardedBy("lock")
private boolean terminated = false;

@GuardedBy("lock")
private Disposable upstreamDisposable;

Execution(Single<ValueT> task) {
this.task = task;
@GuardedBy("lock")
private final List<SingleObserver<? super ValueT>> observers = new ArrayList<>();

Execution(KeyT key, Single<ValueT> upstream) {
this.key = key;
this.upstream = upstream;
}

Single<ValueT> executeIfNot() {
checkState(!isTaskDisposed(), "disposed");

int subscribed = referenceCount.getAndIncrement();
if (taskDisposable.get() == null && subscribed == 0) {
task.subscribe(
new SingleObserver<ValueT>() {
@Override
public void onSubscribe(@NonNull Disposable d) {
taskDisposable.compareAndSet(null, d);
}

@Override
public void onSuccess(@NonNull ValueT value) {
asyncSubject.onNext(value);
asyncSubject.onComplete();
}

@Override
public void onError(@NonNull Throwable e) {
asyncSubject.onError(e);
}
});
int getSubscriberCount() {
synchronized (lock) {
return observers.size();
}
}

@Override
protected void subscribeActual(@NonNull SingleObserver<? super ValueT> observer) {
synchronized (lock) {
checkState(!terminated, "terminated");

boolean shouldSubscribe = observers.isEmpty();

observers.add(observer);

observer.onSubscribe(new ExecutionDisposable(this, observer));

return Single.fromObservable(asyncSubject);
if (shouldSubscribe) {
upstream.subscribe(this);
}
}
}

boolean isTaskTerminated() {
return asyncSubject.hasComplete() || asyncSubject.hasThrowable();
@Override
public void onSubscribe(@NonNull Disposable d) {
synchronized (lock) {
upstreamDisposable = d;

if (terminated) {
d.dispose();
}
}
}

boolean isTaskDisposed() {
return isTaskDisposed.get();
@Override
public void onSuccess(@NonNull ValueT value) {
synchronized (lock) {
if (!terminated) {
inProgress.remove(key);
finished.put(key, value);
terminated = true;

for (SingleObserver<? super ValueT> observer : ImmutableList.copyOf(observers)) {
observer.onSuccess(value);
}
}
}
}

void tryDisposeTask() {
checkState(!isTaskDisposed(), "disposed");
checkState(!isTaskTerminated(), "terminated");
@Override
public void onError(@NonNull Throwable error) {
synchronized (lock) {
if (!terminated) {
inProgress.remove(key);
terminated = true;

if (referenceCount.decrementAndGet() == 0) {
isTaskDisposed.set(true);
asyncSubject.onError(new CancellationException("disposed"));
for (SingleObserver<? super ValueT> observer : ImmutableList.copyOf(observers)) {
observer.onError(error);
}
}
}
}

Disposable d = taskDisposable.get();
if (d != null) {
d.dispose();
void remove(SingleObserver<? super ValueT> observer) {
synchronized (lock) {
observers.remove(observer);

if (observers.isEmpty() && !terminated) {
inProgress.remove(key);
terminated = true;

if (upstreamDisposable != null) {
upstreamDisposable.dispose();
}
}
}
}
}

/** Returns count of subscribers for a task. */
public int getSubscriberCount(KeyT key) {
synchronized (lock) {
Execution<ValueT> execution = inProgress.get(key);
if (execution != null) {
return execution.referenceCount.get();
class ExecutionDisposable implements Disposable {
final Execution execution;
final SingleObserver<? super ValueT> observer;
AtomicBoolean isDisposed = new AtomicBoolean(false);

ExecutionDisposable(Execution execution, SingleObserver<? super ValueT> observer) {
this.execution = execution;
this.observer = observer;
}

@Override
public void dispose() {
if (isDisposed.compareAndSet(false, true)) {
execution.remove(observer);
}
}

return 0;
@Override
public boolean isDisposed() {
return isDisposed.get();
}
}

/**
Expand All @@ -185,62 +241,34 @@ public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {

finished.remove(key);

Execution<ValueT> execution =
inProgress.computeIfAbsent(
key,
ignoredKey -> {
AtomicInteger subscribeTimes = new AtomicInteger(0);
return new Execution<>(
Single.defer(
() -> {
int times = subscribeTimes.incrementAndGet();
checkState(times == 1, "Subscribed more than once to the task");
return task;
}));
});

execution
.executeIfNot()
.subscribe(
new SingleObserver<ValueT>() {
@Override
public void onSubscribe(@NonNull Disposable d) {
emitter.setCancellable(
() -> {
d.dispose();

if (!execution.isTaskTerminated()) {
synchronized (lock) {
execution.tryDisposeTask();
if (execution.isTaskDisposed()) {
inProgress.remove(key);
}
}
}
});
}

@Override
public void onSuccess(@NonNull ValueT value) {
synchronized (lock) {
finished.put(key, value);
inProgress.remove(key);
}

emitter.onSuccess(value);
}

@Override
public void onError(@NonNull Throwable e) {
synchronized (lock) {
inProgress.remove(key);
}

if (!emitter.isDisposed()) {
emitter.onError(e);
}
}
});
Execution execution =
inProgress.computeIfAbsent(key, ignoredKey -> new Execution(key, task));

// We must subscribe the execution within the scope of lock to avoid race condition
// that:
// 1. Two callers get the same execution instance
// 2. One decides to dispose the execution, since no more observers, the execution
// will change to the terminate state
// 3. Another one try to subscribe, will get "terminated" error.
execution.subscribe(
new SingleObserver<ValueT>() {
@Override
public void onSubscribe(@NonNull Disposable d) {
emitter.setDisposable(d);
}

@Override
public void onSuccess(@NonNull ValueT valueT) {
emitter.onSuccess(valueT);
}

@Override
public void onError(@NonNull Throwable e) {
if (!emitter.isDisposed()) {
emitter.onError(e);
}
}
});
}
});
}
Expand Down
Loading

0 comments on commit 07a84ce

Please sign in to comment.