Skip to content

Commit

Permalink
Merge pull request #111 from michaelstonis/feature/latest-await-operator
Browse files Browse the repository at this point in the history
Adds support to only get the latest value when using Async operators
  • Loading branch information
neuecc authored Feb 18, 2024
2 parents fb04d0a + 86c54ca commit f9273a5
Show file tree
Hide file tree
Showing 9 changed files with 772 additions and 395 deletions.
780 changes: 391 additions & 389 deletions README.md

Large diffs are not rendered by default.

90 changes: 89 additions & 1 deletion src/R3/AwaitOperation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ public enum AwaitOperation
/// <summary>All values are sent immediately to the asynchronous method.</summary>
Parallel,
/// <summary>All values are sent immediately to the asynchronous method, but the results are queued and passed to the next operator in order.</summary>
SequentialParallel
SequentialParallel,
/// <summary>Only the latest value is queued, and the next value waits for the completion of the asynchronous method.</summary>
Latest,
}

internal abstract class AwaitOperationSequentialObserver<T> : Observer<T>
Expand Down Expand Up @@ -664,3 +666,89 @@ async void RunQueueWorker() // don't(can't) wait so use async void
}
}
}

internal abstract class AwaitOperationLatestObserver<T> : Observer<T>
{
readonly CancellationTokenSource cancellationTokenSource;
readonly bool configureAwait; // continueOnCapturedContext
readonly Channel<T> channel;
bool completed;

protected override bool AutoDisposeOnCompleted => false; // disable auto-dispose

public AwaitOperationLatestObserver(bool configureAwait)
{
this.cancellationTokenSource = new CancellationTokenSource();
this.configureAwait = configureAwait;
this.channel = ChannelUtility.CreateSingleReadeWriterSingularBounded<T>();

RunQueueWorker();
}

protected override sealed void OnNextCore(T value)
{
channel.Writer.TryWrite(value);
}

protected override sealed void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
PublishOnCompleted(result);
Dispose();
return;
}

Volatile.Write(ref completed, true);
channel.Writer.TryComplete(); // exit wait read loop
}

protected override sealed void DisposeCore()
{
channel.Writer.TryComplete(); // complete writing
cancellationTokenSource.Cancel(); // stop selector await.
}

protected abstract ValueTask OnNextAsync(T value, CancellationToken cancellationToken, bool configureAwait);
protected abstract void PublishOnCompleted(Result result);

async void RunQueueWorker() // don't(can't) wait so use async void
{
var reader = channel.Reader;
var token = cancellationTokenSource.Token;

try
{
while (await reader.WaitToReadAsync(/* don't pass CancellationToken, uses WriterComplete */).ConfigureAwait(configureAwait))
{
while (reader.TryRead(out var item))
{
try
{
if (token.IsCancellationRequested) return;

await OnNextAsync(item, token, configureAwait).ConfigureAwait(configureAwait);
}
catch (Exception ex)
{
if (ex is OperationCanceledException)
{
return;
}
OnErrorResume(ex);
}
}
}

if (Volatile.Read(ref completed))
{
PublishOnCompleted(Result.Success);
Dispose();
}
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
ObservableSystem.GetUnhandledExceptionHandler().Invoke(ex);
}
}
}
13 changes: 13 additions & 0 deletions src/R3/Internal/ChannelUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,21 @@ internal static class ChannelUtility
AllowSynchronousContinuations = true // if false, uses TaskCreationOptions.RunContinuationsAsynchronously so avoid it.
};

static readonly BoundedChannelOptions singularBoundedOptions = new BoundedChannelOptions(1)
{
SingleWriter = true, // in Rx operator, OnNext gurantees synchronous
SingleReader = true, // almostly uses single reader loop
AllowSynchronousContinuations = true, // if false, uses TaskCreationOptions.RunContinuationsAsynchronously so avoid it.
FullMode = BoundedChannelFullMode.DropOldest, // This will ensure that the latest item to come in is always added
};

internal static Channel<T> CreateSingleReadeWriterUnbounded<T>()
{
return Channel.CreateUnbounded<T>(options);
}

internal static Channel<T> CreateSingleReadeWriterSingularBounded<T>()
{
return Channel.CreateBounded<T>(singularBoundedOptions);
}
}
25 changes: 25 additions & 0 deletions src/R3/Operators/SelectAwait.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ protected override IDisposable SubscribeCore(Observer<TResult> observer)
if (maxConcurrent == 0 || maxConcurrent < -1) throw new ArgumentException("maxConcurrent must be a -1 or greater than 1.");
return source.Subscribe(new SelectAwaitSequentialParallelConcurrentLimit(observer, selector, configureAwait, maxConcurrent));
}
case AwaitOperation.Latest:
return source.Subscribe(new SelectAwaitLatest(observer, selector, configureAwait));
default:
throw new ArgumentException();
}
Expand Down Expand Up @@ -248,4 +250,27 @@ protected override void PublishOnCompleted(Result result)
observer.OnCompleted(result);
}
}

sealed class SelectAwaitLatest(Observer<TResult> observer, Func<T, CancellationToken, ValueTask<TResult>> selector, bool configureAwait)
: AwaitOperationLatestObserver<T>(configureAwait)
{
#if NET6_0_OR_GREATER
[AsyncMethodBuilderAttribute(typeof(PoolingAsyncValueTaskMethodBuilder))]
#endif
protected override async ValueTask OnNextAsync(T value, CancellationToken cancellationToken, bool configureAwait)
{
var v = await selector(value, cancellationToken).ConfigureAwait(configureAwait);
observer.OnNext(v);
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void PublishOnCompleted(Result result)
{
observer.OnCompleted(result);
}
}
}
21 changes: 21 additions & 0 deletions src/R3/Operators/SubscribeAwait.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public static IDisposable SubscribeAwait<T>(this Observable<T> source, Func<T, C
return source.Subscribe(new SubscribeAwaitSwitch<T>(onNextAsync, onErrorResume, onCompleted, configureAwait));
case AwaitOperation.SequentialParallel:
throw new ArgumentException("SubscribeAwait does not support SequentialParallel. Use Sequential for sequential operation, use parallel for parallel operation instead.");
case AwaitOperation.Latest:
return source.Subscribe(new SubscribeAwaitLatest<T>(onNextAsync, onErrorResume, onCompleted, configureAwait));
default:
throw new ArgumentException();
}
Expand Down Expand Up @@ -158,3 +160,22 @@ protected override void PublishOnCompleted(Result result)
}
}
}

internal sealed class SubscribeAwaitLatest<T>(Func<T, CancellationToken, ValueTask> onNextAsync, Action<Exception> onErrorResume, Action<Result> onCompleted, bool configureAwait)
: AwaitOperationLatestObserver<T>(configureAwait)
{
protected override ValueTask OnNextAsync(T value, CancellationToken cancellationToken, bool configureAwait)
{
return onNextAsync(value, cancellationToken);
}

protected override void OnErrorResumeCore(Exception error)
{
onErrorResume(error);
}

protected override void PublishOnCompleted(Result result)
{
onCompleted(result);
}
}
46 changes: 41 additions & 5 deletions src/R3/Operators/WhereAwait.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public static Observable<T> WhereAwait<T>(this Observable<T> source, Func<T, Can
}
}

internal sealed class WhereAwait<T>(Observable<T> source, Func<T, CancellationToken, ValueTask<bool>> predicate, AwaitOperation awaitOperations, bool configureAwait, int maxConcurrent) : Observable<T>
internal sealed class WhereAwait<T>(Observable<T> source, Func<T, CancellationToken, ValueTask<bool>> predicate, AwaitOperation awaitOperations, bool configureAwait, int maxConcurrent)
: Observable<T>
{
protected override IDisposable SubscribeCore(Observer<T> observer)
{
Expand All @@ -34,7 +35,7 @@ protected override IDisposable SubscribeCore(Observer<T> observer)
return source.Subscribe(new WhereAwaitParallelConcurrentLimit(observer, predicate, configureAwait, maxConcurrent));
}


case AwaitOperation.SequentialParallel:
if (maxConcurrent == -1)
{
Expand All @@ -45,12 +46,15 @@ protected override IDisposable SubscribeCore(Observer<T> observer)
if (maxConcurrent == 0 || maxConcurrent < -1) throw new ArgumentException("maxConcurrent must be a -1 or greater than 1.");
return source.Subscribe(new WhereAwaitSequentialParallelConcurrentLimit(observer, predicate, configureAwait, maxConcurrent));
}
case AwaitOperation.Latest:
return source.Subscribe(new WhereAwaitLatest(observer, predicate, configureAwait));
default:
throw new ArgumentException();
}
}

sealed class WhereAwaitSequential(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait) : AwaitOperationSequentialObserver<T>(configureAwait)
sealed class WhereAwaitSequential(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait)
: AwaitOperationSequentialObserver<T>(configureAwait)
{

#if NET6_0_OR_GREATER
Expand Down Expand Up @@ -78,7 +82,8 @@ protected override void PublishOnCompleted(Result result)
}
}

sealed class WhereAwaitDrop(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait) : AwaitOperationDropObserver<T>(configureAwait)
sealed class WhereAwaitDrop(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait)
: AwaitOperationDropObserver<T>(configureAwait)
{

#if NET6_0_OR_GREATER
Expand All @@ -103,7 +108,8 @@ protected override void PublishOnCompleted(Result result)
}
}

sealed class WhereAwaitParallel(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait) : AwaitOperationParallelObserver<T>(configureAwait)
sealed class WhereAwaitParallel(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait)
: AwaitOperationParallelObserver<T>(configureAwait)
{

#if NET6_0_OR_GREATER
Expand Down Expand Up @@ -268,4 +274,34 @@ protected override void PublishOnCompleted(Result result)
observer.OnCompleted(result);
}
}

sealed class WhereAwaitLatest(Observer<T> observer, Func<T, CancellationToken, ValueTask<bool>> predicate, bool configureAwait)
: AwaitOperationLatestObserver<T>(configureAwait)
{

#if NET6_0_OR_GREATER
[AsyncMethodBuilderAttribute(typeof(PoolingAsyncValueTaskMethodBuilder))]
#endif
protected override async ValueTask OnNextAsync(T value, CancellationToken cancellationToken, bool configureAwait)
{
if (await predicate(value, cancellationToken).ConfigureAwait(configureAwait))
{
if (!cancellationToken.IsCancellationRequested)
{
observer.OnNext(value);
}
}
}

protected override void OnErrorResumeCore(Exception error)
{
observer.OnErrorResume(error);
}

protected override void PublishOnCompleted(Result result)
{
observer.OnCompleted(result);
}
}

}
94 changes: 94 additions & 0 deletions tests/R3.Tests/OperatorTests/SelectAwaitTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -495,4 +495,98 @@ public void SequentialParallelLimit()

liveList.AssertIsCompleted();
}

[Fact]
public void Latest()
{
SynchronizationContext.SetSynchronizationContext(null); // xUnit insert fucking SynchronizationContext so ignore it.

var subject = new Subject<int>();
var timeProvider = new FakeTimeProvider();

using var liveList = subject
.SelectAwait(async (x, ct) =>
{
await Task.Delay(TimeSpan.FromSeconds(3), timeProvider, ct);
return x * 100;
}, AwaitOperation.Latest, configureAwait: false)
.ToLiveList();

subject.OnNext(1);
subject.OnNext(2);
subject.OnNext(3);
subject.OnNext(4);
subject.OnNext(5);

liveList.AssertEqual([]);

timeProvider.Advance(2);
liveList.AssertEqual([]);

timeProvider.Advance(1);
liveList.AssertEqual([100]);

timeProvider.Advance(3);
liveList.AssertEqual([100,500]);

subject.OnNext(6);
subject.OnNext(7);
subject.OnNext(8);
subject.OnNext(9);

timeProvider.Advance(1);
liveList.AssertEqual([100, 500]);

timeProvider.Advance(2);
liveList.AssertEqual([100, 500, 600]);

timeProvider.Advance(3);
liveList.AssertEqual([100, 500, 600, 900]);

subject.OnCompleted();

liveList.AssertIsCompleted();
}

[Fact]
public async Task LatestCancel()
{
SynchronizationContext.SetSynchronizationContext(null); // xUnit insert fucking SynchronizationContext so ignore it.

var subject = new Subject<int>();
var timeProvider = new FakeTimeProvider();

bool canceled = false;
using var liveList = subject
.SelectAwait(async (x, ct) =>
{
try
{
await Task.Delay(TimeSpan.FromSeconds(3), timeProvider, ct);
return x * 100;
}
catch (OperationCanceledException)
{
canceled = true;
throw;
}
}, AwaitOperation.Latest)
.ToLiveList();

subject.OnNext(1);
subject.OnNext(2);

liveList.AssertEqual([]);

timeProvider.Advance(3);
liveList.AssertEqual([100]);

canceled.Should().BeFalse();

liveList.Dispose();

await Task.Delay(TimeSpan.FromSeconds(1));

canceled.Should().BeTrue();
}
}
Loading

0 comments on commit f9273a5

Please sign in to comment.