Skip to content

Commit

Permalink
Merge pull request #34 from Cysharp/hadashiA/separate-count
Browse files Browse the repository at this point in the history
Add independent impl Count/LongCount
  • Loading branch information
neuecc authored Jan 12, 2024
2 parents 8093301 + 6ab1f84 commit d52aa42
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 12 deletions.
12 changes: 0 additions & 12 deletions src/R3/Operators/AggregateOperators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,6 @@ public static Task<HashSet<T>> ToHashSetAsync<T>(this Observable<T> source, IEqu
}, (value) => value, cancellationToken); // ignore complete
}

// CountAsync using AggregateAsync
public static Task<int> CountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source, 0, static (count, _) => checked(count + 1), Stubs<int>.ReturnSelf, cancellationToken); // ignore complete
}

// LongCountAsync using AggregateAsync
public static Task<long> LongCountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source, 0L, static (count, _) => checked(count + 1), Stubs<long>.ReturnSelf, cancellationToken); // ignore complete
}

public static Task<(T Min, T Max)> MinMaxAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
return AggregateAsync(source,
Expand Down
138 changes: 138 additions & 0 deletions src/R3/Operators/CountAsync.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
namespace R3;

public static partial class ObservableExtensions
{
public static Task<int> CountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
var method = new CountAsync<T>(cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<int> CountAsync<T>(this Observable<T> source, Func<T, bool> predicate, CancellationToken cancellationToken = default)
{
var method = new CountFilterAsync<T>(predicate, cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<long> LongCountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default)
{
var method = new LongCountAsync<T>(cancellationToken);
source.Subscribe(method);
return method.Task;
}

public static Task<long> LongCountAsync<T>(this Observable<T> source, Func<T, bool> predicate, CancellationToken cancellationToken = default)
{
var method = new LongCountFilterAsync<T>(predicate, cancellationToken);
source.Subscribe(method);
return method.Task;
}
}

internal sealed class CountAsync<T>(CancellationToken cancellationToken) : TaskObserverBase<T, int>(cancellationToken)
{
int count;

protected override void OnNextCore(T _)
{
count = checked(count + 1);
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(count);
}
}

internal sealed class CountFilterAsync<T>(Func<T, bool> predicate, CancellationToken cancellationToken) : TaskObserverBase<T, int>(cancellationToken)
{
int count;

protected override void OnNextCore(T value)
{
if (predicate(value))
{
count = checked(count + 1);
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(count);
}
}

internal sealed class LongCountAsync<T>(CancellationToken cancellationToken) : TaskObserverBase<T, long>(cancellationToken)
{
long count;

protected override void OnNextCore(T _)
{
count = checked(count + 1);
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(count);
}
}

internal sealed class LongCountFilterAsync<T>(Func<T, bool> predicate, CancellationToken cancellationToken) : TaskObserverBase<T, long>(cancellationToken)
{
long count;

protected override void OnNextCore(T value)
{
if (predicate(value))
{
count = checked(count + 1);
}
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(count);
}
}
43 changes: 43 additions & 0 deletions tests/R3.Tests/OperatorTests/AggregateTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,49 @@ public async Task ToHashSet()
}

[Fact]
public async Task Min()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();
var min = await source.MinAsync();

min.Should().Be(1);

(await Observable.Return(999).MinAsync()).Should().Be(999);

var task = Observable.Empty<int>().MinAsync();

await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);

var error = Observable.Range(1, 10).Select(x =>
{
if (x == 3) throw new Exception("foo");
return x;
}).OnErrorResumeAsFailure();
await Assert.ThrowsAsync<Exception>(async () => await error.MinAsync());
}

[Fact]
public async Task Max()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();
var min = await source.MaxAsync();

min.Should().Be(10);

(await Observable.Return(999).MaxAsync()).Should().Be(999);

var task = Observable.Empty<int>().MaxAsync();

await Assert.ThrowsAsync<InvalidOperationException>(async () => await task);

var error = Observable.Range(1, 10).Select(x =>
{
if (x == 3) throw new Exception("foo");
return x;
}).OnErrorResumeAsFailure();
await Assert.ThrowsAsync<Exception>(async () => await error.MaxAsync());
}

public async Task Count()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();
Expand Down
64 changes: 64 additions & 0 deletions tests/R3.Tests/OperatorTests/CountTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
namespace R3.Tests.OperatorTests;

public class CountTest
{
[Fact]
public async Task Empty()
{
(await Observable.Empty<int>().CountAsync()).Should().Be(0);
(await Observable.Empty<int>().CountAsync(_ => true)).Should().Be(0);
(await Observable.Empty<long>().LongCountAsync()).Should().Be(0);
(await Observable.Empty<long>().LongCountAsync(_ => true)).Should().Be(0);
}

[Fact]
public async Task MultipleValues()
{
var source = new [] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();

(await source.CountAsync()).Should().Be(8);
(await source.Select(x => (long)x).CountAsync()).Should().Be(8);

(await source.CountAsync(x => x % 2 == 0)).Should().Be(4);
(await source.Select(x => (long)x).CountAsync(x => x % 2== 0)).Should().Be(4);
}

[Fact]
public async Task Filter()
{
var source = new [] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();

(await source.CountAsync(x => x % 2 == 0)).Should().Be(4);
(await source.Select(x => (long)x).CountAsync(x => x % 2== 0)).Should().Be(4);
}

[Fact]
public async Task Error()
{
var error = Observable.Range(0, 10).Select(x =>
{
if (x == 3) throw new Exception("foo");
return x;
});

await Assert.ThrowsAsync<Exception>(async () => await error.CountAsync());
await Assert.ThrowsAsync<Exception>(async () => await error.LongCountAsync());
await Assert.ThrowsAsync<Exception>(async () => await error.OnErrorResumeAsFailure().CountAsync());
await Assert.ThrowsAsync<Exception>(async () => await error.OnErrorResumeAsFailure().LongCountAsync());
}

[Fact]
public async Task PredicateError()
{
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable();

await Assert.ThrowsAsync<Exception>(async () =>
{
await source.CountAsync(_ => throw new Exception("hoge"));
});
await Assert.ThrowsAsync<Exception>(async () =>
{
await source.LongCountAsync(_ => throw new Exception("hoge"));
});
}
}

0 comments on commit d52aa42

Please sign in to comment.