Skip to content

Commit

Permalink
Merge pull request #32 from Cysharp/hadashiA/separate-minmax
Browse files Browse the repository at this point in the history
Add independent implementations of Min/Max
  • Loading branch information
neuecc authored Jan 12, 2024
2 parents 907f4f3 + 6c496c3 commit 8d26a50
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 77 deletions.
32 changes: 0 additions & 32 deletions src/R3/Operators/AggregateOperators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,38 +51,6 @@ public static Task<long> LongCountAsync<T>(this Observable<T> source, Cancellati
return AggregateAsync(source, 0L, static (count, _) => checked(count + 1), Stubs<long>.ReturnSelf, cancellationToken); // ignore complete
}

public static Task<T> MinAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source, (default(T)!, hasValue: false),
static (min, message) =>
{
if (!min.hasValue) return (message, true); // first
return Comparer<T>.Default.Compare(min.Item1, message) < 0 ? (min.Item1, true) : (message, true);
},
static (min) =>
{
if (!min.hasValue) throw new InvalidOperationException("Sequence contains no elements");
return min.Item1;
}, cancellationToken);
}


public static Task<T> MaxAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source, (default(T)!, hasValue: false),
static (max, message) =>
{
if (!max.hasValue) return (message, true); // first
return Comparer<T>.Default.Compare(max.Item1, message) > 0 ? (max.Item1, true) : (message, true);
},
static (max) =>
{
if (!max.hasValue) throw new InvalidOperationException("Sequence contains no elements");
return max.Item1;
}, cancellationToken);
}


public static Task<(T Min, T Max)> MinMaxAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source,
Expand Down
121 changes: 121 additions & 0 deletions src/R3/Operators/MaxAsync.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
namespace R3;

public static partial class ObservableExtensions
{
public static Task<T> MaxAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
var method = new MaxAsync<T>(Comparer<T>.Default, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<T> MaxAsync<T>(this Observable<T> source, IComparer<T> comparer, CancellationToken cancellationToken = default)
{
var method = new MaxAsync<T>(comparer, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<TResult> MaxAsync<TSource, TResult>(this Observable<TSource> source, Func<TSource, TResult> selector, CancellationToken cancellationToken = default)
{
var method = new MaxAsync<TSource, TResult>(selector, Comparer<TResult>.Default, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<TResult> MaxAsync<TSource, TResult>(this Observable<TSource> source, Func<TSource, TResult> selector, IComparer<TResult> comparer, CancellationToken cancellationToken = default)
{
var method = new MaxAsync<TSource, TResult>(selector, comparer, cancellationToken);
source.Subscribe(method);
return method.Task;
}
}

internal sealed class MaxAsync<T>(IComparer<T> comparer, CancellationToken cancellation) : TaskObserverBase<T, T>(cancellation)
{
T current = default!;
bool hasValue;

protected override void OnNextCore(T value)
{
if (!hasValue)
{
hasValue = true;
current = value;
return;
}

if (comparer.Compare(value, current) > 0)
{
current = value;
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}

if (hasValue)
{
TrySetResult(current);
}
else
{
TrySetException(new InvalidOperationException("Sequence contains no elements"));
}
}
}

internal sealed class MaxAsync<TSource, TResult>(Func<TSource, TResult> selector, IComparer<TResult> comparer, CancellationToken cancellation) : TaskObserverBase<TSource, TResult>(cancellation)
{
TResult current = default!;
bool hasValue;

protected override void OnNextCore(TSource value)
{
var nextValue = selector(value);
if (!hasValue)
{
hasValue = true;
current = nextValue;
return;
}

if (comparer.Compare(nextValue, current) > 0)
{
current = nextValue;
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}

if (hasValue)
{
TrySetResult(current);
}
else
{
TrySetException(new InvalidOperationException("Sequence contains no elements"));
}
}
}
121 changes: 121 additions & 0 deletions src/R3/Operators/MinAsync.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
namespace R3;

public static partial class ObservableExtensions
{
public static Task<T> MinAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
var method = new MinAsync<T>(Comparer<T>.Default, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<T> MinAsync<T>(this Observable<T> source, IComparer<T> comparer, CancellationToken cancellationToken = default)
{
var method = new MinAsync<T>(comparer, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<TResult> MinAsync<TSource, TResult>(this Observable<TSource> source, Func<TSource, TResult> selector, CancellationToken cancellationToken = default)
{
var method = new MinAsync<TSource, TResult>(selector, Comparer<TResult>.Default, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<TResult> MinAsync<TSource, TResult>(this Observable<TSource> source, Func<TSource, TResult> selector, IComparer<TResult> comparer, CancellationToken cancellationToken = default)
{
var method = new MinAsync<TSource, TResult>(selector, comparer, cancellationToken);
source.Subscribe(method);
return method.Task;
}
}

internal sealed class MinAsync<T>(IComparer<T> comparer, CancellationToken cancellation) : TaskObserverBase<T, T>(cancellation)
{
T current = default!;
bool hasValue;

protected override void OnNextCore(T value)
{
if (!hasValue)
{
hasValue = true;
current = value;
return;
}

if (comparer.Compare(value, current) < 0)
{
current = value;
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}

if (hasValue)
{
TrySetResult(current);
}
else
{
TrySetException(new InvalidOperationException("Sequence contains no elements"));
}
}
}

internal sealed class MinAsync<TSource, TResult>(Func<TSource, TResult> selector, IComparer<TResult> comparer, CancellationToken cancellation) : TaskObserverBase<TSource, TResult>(cancellation)
{
TResult current = default!;
bool hasValue;

protected override void OnNextCore(TSource value)
{
var nextValue = selector(value);
if (!hasValue)
{
hasValue = true;
current = nextValue;
return;
}

if (comparer.Compare(nextValue, current) < 0)
{
current = nextValue;
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}

if (hasValue)
{
TrySetResult(current);
}
else
{
TrySetException(new InvalidOperationException("Sequence contains no elements"));
}
}
}
46 changes: 1 addition & 45 deletions tests/R3.Tests/OperatorTests/AggregateTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace R3.Tests.OperatorTests;
public class AggregateTest
{
[Fact]
public async Task Aggreagte()
public async Task Aggregate()
{
var publisher = new Subject<int>();

Expand Down Expand Up @@ -90,50 +90,6 @@ public async Task LongCount()
await Assert.ThrowsAsync<Exception>(async () => await error.LongCountAsync());
}

[Fact]
public async Task Min()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();
var min = await source.MinAsync();

min.Should().Be(1);

(await Observable.Return(999).MinAsync()).Should().Be(999);

var task = Observable.Empty<int>().MinAsync();

await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);

var error = Observable.Range(1, 10).Select(x =>
{
if (x == 3) throw new Exception("foo");
return x;
}).OnErrorResumeAsFailure();
await Assert.ThrowsAsync<Exception>(async () => await error.MinAsync());
}

[Fact]
public async Task Max()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();
var min = await source.MaxAsync();

min.Should().Be(10);

(await Observable.Return(999).MaxAsync()).Should().Be(999);

var task = Observable.Empty<int>().MaxAsync();

await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);

var error = Observable.Range(1, 10).Select(x =>
{
if (x == 3) throw new Exception("foo");
return x;
}).OnErrorResumeAsFailure();
await Assert.ThrowsAsync<Exception>(async () => await error.MaxAsync());
}

[Fact]
public async Task MinMax()
{
Expand Down
Loading

0 comments on commit 8d26a50

Please sign in to comment.