-
-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #34 from Cysharp/hadashiA/separate-count
Add independent impl Count/LongCount
- Loading branch information
Showing
4 changed files
with
245 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
namespace R3; | ||
|
||
public static partial class ObservableExtensions | ||
{ | ||
public static Task<int> CountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default) | ||
{ | ||
var method = new CountAsync<T>(cancellationToken); | ||
source.Subscribe(method); | ||
return method.Task; | ||
} | ||
|
||
public static Task<int> CountAsync<T>(this Observable<T> source, Func<T, bool> predicate, CancellationToken cancellationToken = default) | ||
{ | ||
var method = new CountFilterAsync<T>(predicate, cancellationToken); | ||
source.Subscribe(method); | ||
return method.Task; | ||
} | ||
|
||
public static Task<long> LongCountAsync<T>(this Observable<T> source, CancellationToken cancellationToken = default) | ||
{ | ||
var method = new LongCountAsync<T>(cancellationToken); | ||
source.Subscribe(method); | ||
return method.Task; | ||
} | ||
|
||
public static Task<long> LongCountAsync<T>(this Observable<T> source, Func<T, bool> predicate, CancellationToken cancellationToken = default) | ||
{ | ||
var method = new LongCountFilterAsync<T>(predicate, cancellationToken); | ||
source.Subscribe(method); | ||
return method.Task; | ||
} | ||
} | ||
|
||
internal sealed class CountAsync<T>(CancellationToken cancellationToken) : TaskObserverBase<T, int>(cancellationToken) | ||
{ | ||
int count; | ||
|
||
protected override void OnNextCore(T _) | ||
{ | ||
count = checked(count + 1); | ||
} | ||
|
||
protected override void OnErrorResumeCore(Exception error) | ||
{ | ||
TrySetException(error); | ||
} | ||
|
||
protected override void OnCompletedCore(Result result) | ||
{ | ||
if (result.IsFailure) | ||
{ | ||
TrySetException(result.Exception); | ||
return; | ||
} | ||
TrySetResult(count); | ||
} | ||
} | ||
|
||
internal sealed class CountFilterAsync<T>(Func<T, bool> predicate, CancellationToken cancellationToken) : TaskObserverBase<T, int>(cancellationToken) | ||
{ | ||
int count; | ||
|
||
protected override void OnNextCore(T value) | ||
{ | ||
if (predicate(value)) | ||
{ | ||
count = checked(count + 1); | ||
} | ||
} | ||
|
||
protected override void OnErrorResumeCore(Exception error) | ||
{ | ||
TrySetException(error); | ||
} | ||
|
||
protected override void OnCompletedCore(Result result) | ||
{ | ||
if (result.IsFailure) | ||
{ | ||
TrySetException(result.Exception); | ||
return; | ||
} | ||
TrySetResult(count); | ||
} | ||
} | ||
|
||
internal sealed class LongCountAsync<T>(CancellationToken cancellationToken) : TaskObserverBase<T, long>(cancellationToken) | ||
{ | ||
long count; | ||
|
||
protected override void OnNextCore(T _) | ||
{ | ||
count = checked(count + 1); | ||
} | ||
|
||
protected override void OnErrorResumeCore(Exception error) | ||
{ | ||
TrySetException(error); | ||
} | ||
|
||
protected override void OnCompletedCore(Result result) | ||
{ | ||
if (result.IsFailure) | ||
{ | ||
TrySetException(result.Exception); | ||
return; | ||
} | ||
TrySetResult(count); | ||
} | ||
} | ||
|
||
internal sealed class LongCountFilterAsync<T>(Func<T, bool> predicate, CancellationToken cancellationToken) : TaskObserverBase<T, long>(cancellationToken) | ||
{ | ||
long count; | ||
|
||
protected override void OnNextCore(T value) | ||
{ | ||
if (predicate(value)) | ||
{ | ||
count = checked(count + 1); | ||
} | ||
} | ||
|
||
protected override void OnErrorResumeCore(Exception error) | ||
{ | ||
TrySetException(error); | ||
} | ||
|
||
protected override void OnCompletedCore(Result result) | ||
{ | ||
if (result.IsFailure) | ||
{ | ||
TrySetException(result.Exception); | ||
return; | ||
} | ||
TrySetResult(count); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
namespace R3.Tests.OperatorTests; | ||
|
||
public class CountTest | ||
{ | ||
[Fact] | ||
public async Task Empty() | ||
{ | ||
(await Observable.Empty<int>().CountAsync()).Should().Be(0); | ||
(await Observable.Empty<int>().CountAsync(_ => true)).Should().Be(0); | ||
(await Observable.Empty<long>().LongCountAsync()).Should().Be(0); | ||
(await Observable.Empty<long>().LongCountAsync(_ => true)).Should().Be(0); | ||
} | ||
|
||
[Fact] | ||
public async Task MultipleValues() | ||
{ | ||
var source = new [] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable(); | ||
|
||
(await source.CountAsync()).Should().Be(8); | ||
(await source.Select(x => (long)x).CountAsync()).Should().Be(8); | ||
|
||
(await source.CountAsync(x => x % 2 == 0)).Should().Be(4); | ||
(await source.Select(x => (long)x).CountAsync(x => x % 2== 0)).Should().Be(4); | ||
} | ||
|
||
[Fact] | ||
public async Task Filter() | ||
{ | ||
var source = new [] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable(); | ||
|
||
(await source.CountAsync(x => x % 2 == 0)).Should().Be(4); | ||
(await source.Select(x => (long)x).CountAsync(x => x % 2== 0)).Should().Be(4); | ||
} | ||
|
||
[Fact] | ||
public async Task Error() | ||
{ | ||
var error = Observable.Range(0, 10).Select(x => | ||
{ | ||
if (x == 3) throw new Exception("foo"); | ||
return x; | ||
}); | ||
|
||
await Assert.ThrowsAsync<Exception>(async () => await error.CountAsync()); | ||
await Assert.ThrowsAsync<Exception>(async () => await error.LongCountAsync()); | ||
await Assert.ThrowsAsync<Exception>(async () => await error.OnErrorResumeAsFailure().CountAsync()); | ||
await Assert.ThrowsAsync<Exception>(async () => await error.OnErrorResumeAsFailure().LongCountAsync()); | ||
} | ||
|
||
[Fact] | ||
public async Task PredicateError() | ||
{ | ||
var source = new int[] { 1, 10, 1, 3, 4, 6, 7, 4 }.ToObservable(); | ||
|
||
await Assert.ThrowsAsync<Exception>(async () => | ||
{ | ||
await source.CountAsync(_ => throw new Exception("hoge")); | ||
}); | ||
await Assert.ThrowsAsync<Exception>(async () => | ||
{ | ||
await source.LongCountAsync(_ => throw new Exception("hoge")); | ||
}); | ||
} | ||
} |