diff --git a/src/libraries/System.Threading.Tasks.Parallel/ref/System.Threading.Tasks.Parallel.cs b/src/libraries/System.Threading.Tasks.Parallel/ref/System.Threading.Tasks.Parallel.cs index 72f6a5b6071b0..41be9b17b088c 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/ref/System.Threading.Tasks.Parallel.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/ref/System.Threading.Tasks.Parallel.cs @@ -36,6 +36,12 @@ public static partial class Parallel public static System.Threading.Tasks.ParallelLoopResult ForEach(System.Collections.Generic.IEnumerable source, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } public static System.Threading.Tasks.ParallelLoopResult ForEach(System.Collections.Generic.IEnumerable source, System.Threading.Tasks.ParallelOptions parallelOptions, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } public static System.Threading.Tasks.ParallelLoopResult ForEach(System.Collections.Generic.IEnumerable source, System.Threading.Tasks.ParallelOptions parallelOptions, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IEnumerable source, System.Func body) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IEnumerable source, CancellationToken cancellationToken, System.Func body) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IEnumerable source, System.Threading.Tasks.ParallelOptions parallelOptions, System.Func body) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IAsyncEnumerable source, System.Func body) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IAsyncEnumerable source, CancellationToken cancellationToken, System.Func body) { throw null; } + public static System.Threading.Tasks.Task ForEachAsync(System.Collections.Generic.IAsyncEnumerable source, System.Threading.Tasks.ParallelOptions parallelOptions, System.Func body) { throw null; } public static System.Threading.Tasks.ParallelLoopResult For(int fromInclusive, int toExclusive, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } public static System.Threading.Tasks.ParallelLoopResult For(int fromInclusive, int toExclusive, System.Threading.Tasks.ParallelOptions parallelOptions, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } public static System.Threading.Tasks.ParallelLoopResult For(long fromInclusive, long toExclusive, System.Func localInit, System.Func body, System.Action localFinally) { throw null; } diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/Resources/Strings.resx b/src/libraries/System.Threading.Tasks.Parallel/src/Resources/Strings.resx index e9454fdaae755..e923e39df7e2b 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/Resources/Strings.resx +++ b/src/libraries/System.Threading.Tasks.Parallel/src/Resources/Strings.resx @@ -64,13 +64,13 @@ This method requires the use of an OrderedPartitioner with the KeysNormalized property set to true. - The Partitioner used here must support dynamic partitioning. + The Partitioner must support dynamic partitioning. - The Partitioner used here returned a null partitioner source. + The Partitioner returned a null partitioner source. - The Partitioner source returned a null enumerator. + The source returned a null enumerator. Break was called after Stop was called. @@ -81,4 +81,4 @@ This method is not supported. - \ No newline at end of file + diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System.Threading.Tasks.Parallel.csproj b/src/libraries/System.Threading.Tasks.Parallel/src/System.Threading.Tasks.Parallel.csproj index 3d0ccd82d99c3..84566e7f073db 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System.Threading.Tasks.Parallel.csproj +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System.Threading.Tasks.Parallel.csproj @@ -7,17 +7,20 @@ + + + diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs new file mode 100644 index 0000000000000..ee3b794e09dfa --- /dev/null +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs @@ -0,0 +1,559 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Threading.Tasks +{ + public static partial class Parallel + { + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + /// The operation will execute at most operations in parallel. + public static Task ForEachAsync(IEnumerable source, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, DefaultDegreeOfParallelism, TaskScheduler.Default, default(CancellationToken), body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// A cancellation token that may be used to cancel the for each operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + /// The operation will execute at most operations in parallel. + public static Task ForEachAsync(IEnumerable source, CancellationToken cancellationToken, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, DefaultDegreeOfParallelism, TaskScheduler.Default, cancellationToken, body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// An object that configures the behavior of this operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + public static Task ForEachAsync(IEnumerable source, ParallelOptions parallelOptions, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (parallelOptions is null) + { + throw new ArgumentNullException(nameof(parallelOptions)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, parallelOptions.EffectiveMaxConcurrencyLevel, parallelOptions.EffectiveTaskScheduler, parallelOptions.CancellationToken, body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// A integer indicating how many operations to allow to run in parallel. + /// The task scheduler on which all code should execute. + /// A cancellation token that may be used to cancel the for each operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + private static Task ForEachAsync(IEnumerable source, int dop, TaskScheduler scheduler, CancellationToken cancellationToken, Func body) + { + Debug.Assert(source != null); + Debug.Assert(scheduler != null); + Debug.Assert(body != null); + + // One fast up-front check for cancellation before we start the whole operation. + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + if (dop < 0) + { + dop = DefaultDegreeOfParallelism; + } + + // The worker body. Each worker will execute this same body. + Func taskBody = static async o => + { + var state = (SyncForEachAsyncState)o; + bool launchedNext = false; + +#pragma warning disable CA2007 // Explicitly don't use ConfigureAwait, as we want to perform all work on the specified scheduler that's now current + try + { + // Continue to loop while there are more elements to be processed. + while (!state.Cancellation.IsCancellationRequested) + { + // Get the next element from the enumerator. This requires asynchronously locking around MoveNextAsync/Current. + TSource element; + lock (state) + { + if (!state.Enumerator.MoveNext()) + { + break; + } + + element = state.Enumerator.Current; + } + + // If the remaining dop allows it and we've not yet queued the next worker, do so now. We wait + // until after we've grabbed an item from the enumerator to a) avoid unnecessary contention on the + // serialized resource, and b) avoid queueing another work if there aren't any more items. Each worker + // is responsible only for creating the next worker, which in turn means there can't be any contention + // on creating workers (though it's possible one worker could be executing while we're creating the next). + if (!launchedNext) + { + launchedNext = true; + state.QueueWorkerIfDopAvailable(); + } + + // Process the loop body. + await state.LoopBody(element, state.Cancellation.Token); + } + } + catch (Exception e) + { + // Record the failure and then don't let the exception propagate. The last worker to complete + // will propagate exceptions as is appropriate to the top-level task. + state.RecordException(e); + } + finally + { + // If we're the last worker to complete, clean up and complete the operation. + if (state.SignalWorkerCompletedIterating()) + { + try + { + state.Dispose(); + } + catch (Exception e) + { + state.RecordException(e); + } + + // Finally, complete the task returned to the ForEachAsync caller. + // This must be the very last thing done. + state.Complete(); + } + } +#pragma warning restore CA2007 + }; + + try + { + // Construct a state object that encapsulates all state to be passed and shared between + // the workers, and queues the first worker. + var state = new SyncForEachAsyncState(source, taskBody, dop, scheduler, cancellationToken, body); + state.QueueWorkerIfDopAvailable(); + return state.Task; + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + /// The operation will execute at most operations in parallel. + public static Task ForEachAsync(IAsyncEnumerable source, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, DefaultDegreeOfParallelism, TaskScheduler.Default, default(CancellationToken), body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// A cancellation token that may be used to cancel the for each operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + /// The operation will execute at most operations in parallel. + public static Task ForEachAsync(IAsyncEnumerable source, CancellationToken cancellationToken, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, DefaultDegreeOfParallelism, TaskScheduler.Default, cancellationToken, body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// An object that configures the behavior of this operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + public static Task ForEachAsync(IAsyncEnumerable source, ParallelOptions parallelOptions, Func body) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } + if (parallelOptions is null) + { + throw new ArgumentNullException(nameof(parallelOptions)); + } + if (body is null) + { + throw new ArgumentNullException(nameof(body)); + } + + return ForEachAsync(source, parallelOptions.EffectiveMaxConcurrencyLevel, parallelOptions.EffectiveTaskScheduler, parallelOptions.CancellationToken, body); + } + + /// Executes a for each operation on an in which iterations may run in parallel. + /// The type of the data in the source. + /// An enumerable data source. + /// A integer indicating how many operations to allow to run in parallel. + /// The task scheduler on which all code should execute. + /// A cancellation token that may be used to cancel the for each operation. + /// An asynchronous delegate that is invoked once per element in the data source. + /// The exception that is thrown when the argument or argument is null. + /// A task that represents the entire for each operation. + private static Task ForEachAsync(IAsyncEnumerable source, int dop, TaskScheduler scheduler, CancellationToken cancellationToken, Func body) + { + Debug.Assert(source != null); + Debug.Assert(scheduler != null); + Debug.Assert(body != null); + + // One fast up-front check for cancellation before we start the whole operation. + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + if (dop < 0) + { + dop = DefaultDegreeOfParallelism; + } + + // The worker body. Each worker will execute this same body. + Func taskBody = static async o => + { + var state = (AsyncForEachAsyncState)o; + bool launchedNext = false; + +#pragma warning disable CA2007 // Explicitly don't use ConfigureAwait, as we want to perform all work on the specified scheduler that's now current + try + { + // Continue to loop while there are more elements to be processed. + while (!state.Cancellation.IsCancellationRequested) + { + // Get the next element from the enumerator. This requires asynchronously locking around MoveNextAsync/Current. + TSource element; + await state.Lock.WaitAsync(state.Cancellation.Token); + try + { + if (!await state.Enumerator.MoveNextAsync()) + { + break; + } + + element = state.Enumerator.Current; + } + finally + { + state.Lock.Release(); + } + + // If the remaining dop allows it and we've not yet queued the next worker, do so now. We wait + // until after we've grabbed an item from the enumerator to a) avoid unnecessary contention on the + // serialized resource, and b) avoid queueing another work if there aren't any more items. Each worker + // is responsible only for creating the next worker, which in turn means there can't be any contention + // on creating workers (though it's possible one worker could be executing while we're creating the next). + if (!launchedNext) + { + launchedNext = true; + state.QueueWorkerIfDopAvailable(); + } + + // Process the loop body. + await state.LoopBody(element, state.Cancellation.Token); + } + } + catch (Exception e) + { + // Record the failure and then don't let the exception propagate. The last worker to complete + // will propagate exceptions as is appropriate to the top-level task. + state.RecordException(e); + } + finally + { + // If we're the last worker to complete, clean up and complete the operation. + if (state.SignalWorkerCompletedIterating()) + { + try + { + await state.DisposeAsync(); + } + catch (Exception e) + { + state.RecordException(e); + } + + // Finally, complete the task returned to the ForEachAsync caller. + // This must be the very last thing done. + state.Complete(); + } + } +#pragma warning restore CA2007 + }; + + try + { + // Construct a state object that encapsulates all state to be passed and shared between + // the workers, and queues the first worker. + var state = new AsyncForEachAsyncState(source, taskBody, dop, scheduler, cancellationToken, body); + state.QueueWorkerIfDopAvailable(); + return state.Task; + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// Gets the default degree of parallelism to use when none is explicitly provided. + private static int DefaultDegreeOfParallelism => Environment.ProcessorCount; + + /// Stores the state associated with a ForEachAsync operation, shared between all its workers. + /// Specifies the type of data being enumerated. + private abstract class ForEachAsyncState : TaskCompletionSource, IThreadPoolWorkItem + { + /// Registration with caller-provided cancellation token. + protected readonly CancellationTokenRegistration _registration; + /// + /// The delegate to invoke on each worker to run the enumerator processing loop. + /// + /// + /// This could have been an action rather than a func, but it returns a task so that the task body is an async Task + /// method rather than async void, even though the worker body catches all exceptions and the returned Task is ignored. + /// + private readonly Func _taskBody; + /// The on which all work should be performed. + private readonly TaskScheduler _scheduler; + /// The present at the time of the ForEachAsync invocation. This is only used if on the default scheduler. + private readonly ExecutionContext? _executionContext; + + /// The number of outstanding workers. When this hits 0, the operation has completed. + private int _completionRefCount; + /// Any exceptions incurred during execution. + private List? _exceptions; + /// The number of workers that may still be created. + private int _remainingDop; + + /// The delegate to invoke for each element yielded by the enumerator. + public readonly Func LoopBody; + /// The internal token source used to cancel pending work. + public readonly CancellationTokenSource Cancellation = new CancellationTokenSource(); + + /// Initializes the state object. + protected ForEachAsyncState(Func taskBody, int dop, TaskScheduler scheduler, CancellationToken cancellationToken, Func body) + { + _taskBody = taskBody; + _remainingDop = dop; + LoopBody = body; + _scheduler = scheduler; + if (scheduler == TaskScheduler.Default) + { + _executionContext = ExecutionContext.Capture(); + } + + _registration = cancellationToken.UnsafeRegister(static o => ((ForEachAsyncState)o!).Cancellation.Cancel(), this); + } + + /// Queues another worker if allowed by the remaining degree of parallelism permitted. + /// This is not thread-safe and must only be invoked by one worker at a time. + public void QueueWorkerIfDopAvailable() + { + if (_remainingDop > 0) + { + _remainingDop--; + + // Queue the invocation of the worker/task body. Note that we explicitly do not pass a cancellation token here, + // as the task body is what's responsible for completing the ForEachAsync task, for decrementing the reference count + // on pending tasks, and for cleaning up state. If a token were passed to StartNew (which simply serves to stop the + // task from starting to execute if it hasn't yet by the time cancellation is requested), all of that logic could be + // skipped, and bad things could ensue, e.g. deadlocks, leaks, etc. Also note that we need to increment the pending + // work item ref count prior to queueing the worker in order to avoid race conditions that could lead to temporarily + // and erroneously bouncing at zero, which would trigger completion too early. + Interlocked.Increment(ref _completionRefCount); + if (_scheduler == TaskScheduler.Default) + { + // If the scheduler is the default, we can avoid the overhead of the StartNew Task by just queueing + // this state object as the work item. + ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: false); + } + else + { + // We're targeting a non-default TaskScheduler, so queue the task body to it. + Task.Factory.StartNew(_taskBody!, this, default(CancellationToken), TaskCreationOptions.DenyChildAttach, _scheduler); + } + } + } + + /// Signals that the worker has completed iterating. + /// true if this is the last worker to complete iterating; otherwise, false. + public bool SignalWorkerCompletedIterating() => Interlocked.Decrement(ref _completionRefCount) == 0; + + /// Stores an exception and triggers cancellation in order to alert all workers to stop as soon as possible. + /// The exception. + public void RecordException(Exception e) + { + lock (this) + { + (_exceptions ??= new List()).Add(e); + } + + Cancellation.Cancel(); + } + + /// Completes the ForEachAsync task based on the status of this state object. + public void Complete() + { + Debug.Assert(_completionRefCount == 0, $"Expected {nameof(_completionRefCount)} == 0, got {_completionRefCount}"); + + bool taskSet; + if (_registration.Token.IsCancellationRequested) + { + // The externally provided token had cancellation requested. Assume that any exceptions + // then are due to that, and just cancel the resulting task. + taskSet = TrySetCanceled(_registration.Token); + } + else if (_exceptions is null) + { + // Everything completed successfully. + taskSet = TrySetResult(); + } + else + { + // Fault with all of the received exceptions, but filter out those due to inner cancellation, + // as they're effectively an implementation detail and stem from the original exception. + Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated."); + for (int i = 0; i < _exceptions.Count; i++) + { + if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token) + { + _exceptions[i] = null!; + } + } + _exceptions.RemoveAll(e => e is null); + Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation."); + taskSet = TrySetException(_exceptions); + } + + Debug.Assert(taskSet, "Complete should only be called once."); + } + + /// Executes the task body using the captured when ForEachAsync was invoked. + void IThreadPoolWorkItem.Execute() + { + Debug.Assert(_scheduler == TaskScheduler.Default, $"Expected {nameof(_scheduler)} == TaskScheduler.Default, got {_scheduler}"); + + if (_executionContext is null) + { + _taskBody(this); + } + else + { + ExecutionContext.Run(_executionContext, static o => ((ForEachAsyncState)o!)._taskBody(o), this); + } + } + } + + /// Stores the state associated with an IEnumerable ForEachAsync operation, shared between all its workers. + /// Specifies the type of data being enumerated. + private sealed class SyncForEachAsyncState : ForEachAsyncState, IDisposable + { + public readonly IEnumerator Enumerator; + + public SyncForEachAsyncState( + IEnumerable source, Func taskBody, + int dop, TaskScheduler scheduler, CancellationToken cancellationToken, + Func body) : + base(taskBody, dop, scheduler, cancellationToken, body) + { + Enumerator = source.GetEnumerator() ?? throw new InvalidOperationException(SR.Parallel_ForEach_NullEnumerator); + } + + public void Dispose() + { + _registration.Dispose(); + Enumerator.Dispose(); + } + } + + /// Stores the state associated with an IAsyncEnumerable ForEachAsync operation, shared between all its workers. + /// Specifies the type of data being enumerated. + private sealed class AsyncForEachAsyncState : ForEachAsyncState, IAsyncDisposable + { + public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1); + public readonly IAsyncEnumerator Enumerator; + + public AsyncForEachAsyncState( + IAsyncEnumerable source, Func taskBody, + int dop, TaskScheduler scheduler, CancellationToken cancellationToken, + Func body) : + base(taskBody, dop, scheduler, cancellationToken, body) + { + Enumerator = source.GetAsyncEnumerator(Cancellation.Token) ?? throw new InvalidOperationException(SR.Parallel_ForEach_NullEnumerator); + } + + public ValueTask DisposeAsync() + { + _registration.Dispose(); + return Enumerator.DisposeAsync(); + } + } + } +} diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs index 09f40edb97d98..cffded009d03c 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs @@ -126,7 +126,7 @@ internal int EffectiveMaxConcurrencyLevel /// The class provides library-based data parallel replacements /// for common operations such as for loops, for each loops, and execution of a set of statements. /// - public static class Parallel + public static partial class Parallel { // static counter for generating unique Fork/Join Context IDs to be used in ETW events internal static int s_forkJoinContextID; diff --git a/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs b/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs new file mode 100644 index 0000000000000..e8cc5acac9a35 --- /dev/null +++ b/src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs @@ -0,0 +1,997 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using Xunit; + +namespace System.Threading.Tasks.Tests +{ + public sealed class ParallelForEachAsyncTests + { + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void InvalidArguments_ThrowsException() + { + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IEnumerable)null, (item, cancellationToken) => default); }); + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IEnumerable)null, CancellationToken.None, (item, cancellationToken) => default); }); + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IEnumerable)null, new ParallelOptions(), (item, cancellationToken) => default); }); + + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IAsyncEnumerable)null, (item, cancellationToken) => default); }); + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IAsyncEnumerable)null, CancellationToken.None, (item, cancellationToken) => default); }); + AssertExtensions.Throws("source", () => { Parallel.ForEachAsync((IAsyncEnumerable)null, new ParallelOptions(), (item, cancellationToken) => default); }); + + AssertExtensions.Throws("parallelOptions", () => { Parallel.ForEachAsync(Enumerable.Range(1, 10), null, (item, cancellationToken) => default); }); + AssertExtensions.Throws("parallelOptions", () => { Parallel.ForEachAsync(EnumerableRangeAsync(1, 10), null, (item, cancellationToken) => default); }); + + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(Enumerable.Range(1, 10), null); }); + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(Enumerable.Range(1, 10), CancellationToken.None, null); }); + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(Enumerable.Range(1, 10), new ParallelOptions(), null); }); + + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(EnumerableRangeAsync(1, 10), null); }); + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(EnumerableRangeAsync(1, 10), CancellationToken.None, null); }); + AssertExtensions.Throws("body", () => { Parallel.ForEachAsync(EnumerableRangeAsync(1, 10), new ParallelOptions(), null); }); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void PreCanceled_CancelsSynchronously() + { + var box = new StrongBox(false); + var cts = new CancellationTokenSource(); + cts.Cancel(); + + void AssertCanceled(Task t) + { + Assert.True(t.IsCanceled); + var oce = Assert.ThrowsAny(() => t.GetAwaiter().GetResult()); + Assert.Equal(cts.Token, oce.CancellationToken); + } + + Func body = (item, cancellationToken) => + { + Assert.False(true, "Should not have been invoked"); + return default; + }; + + AssertCanceled(Parallel.ForEachAsync(MarkStart(box), cts.Token, body)); + AssertCanceled(Parallel.ForEachAsync(MarkStartAsync(box), cts.Token, body)); + + AssertCanceled(Parallel.ForEachAsync(MarkStart(box), new ParallelOptions { CancellationToken = cts.Token }, body)); + AssertCanceled(Parallel.ForEachAsync(MarkStartAsync(box), new ParallelOptions { CancellationToken = cts.Token }, body)); + + Assert.False(box.Value); + + static IEnumerable MarkStart(StrongBox box) + { + Assert.False(box.Value); + box.Value = true; + yield return 0; + } + + static async IAsyncEnumerable MarkStartAsync(StrongBox box) + { + Assert.False(box.Value); + box.Value = true; + yield return 0; + + await Task.Yield(); + } + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(-1)] + [InlineData(1)] + [InlineData(2)] + [InlineData(4)] + [InlineData(128)] + public async Task Dop_WorkersCreatedRespectingLimit_Sync(int dop) + { + static IEnumerable IterateUntilSet(StrongBox box) + { + int counter = 0; + while (!box.Value) + { + yield return counter++; + } + } + + var box = new StrongBox(false); + + int activeWorkers = 0; + var block = new TaskCompletionSource(); + + Task t = Parallel.ForEachAsync(IterateUntilSet(box), new ParallelOptions { MaxDegreeOfParallelism = dop }, async (item, cancellationToken) => + { + Interlocked.Increment(ref activeWorkers); + await block.Task; + }); + Assert.False(t.IsCompleted); + + await Task.Delay(20); // give the loop some time to run + + box.Value = true; + block.SetResult(); + await t; + + Assert.InRange(activeWorkers, 0, dop == -1 ? Environment.ProcessorCount : dop); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(-1)] + [InlineData(1)] + [InlineData(2)] + [InlineData(4)] + [InlineData(128)] + public async Task Dop_WorkersCreatedRespectingLimitAndTaskScheduler_Sync(int dop) + { + static IEnumerable IterateUntilSet(StrongBox box) + { + int counter = 0; + while (!box.Value) + { + yield return counter++; + } + } + + var box = new StrongBox(false); + + int activeWorkers = 0; + var block = new TaskCompletionSource(); + + const int MaxSchedulerLimit = 2; + + Task t = Parallel.ForEachAsync(IterateUntilSet(box), new ParallelOptions { MaxDegreeOfParallelism = dop, TaskScheduler = new MaxConcurrencyLevelPassthroughTaskScheduler(MaxSchedulerLimit) }, async (item, cancellationToken) => + { + Interlocked.Increment(ref activeWorkers); + await block.Task; + }); + Assert.False(t.IsCompleted); + + await Task.Delay(20); // give the loop some time to run + + box.Value = true; + block.SetResult(); + await t; + + Assert.InRange(activeWorkers, 0, Math.Min(MaxSchedulerLimit, dop == -1 ? Environment.ProcessorCount : dop)); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Dop_NegativeTaskSchedulerLimitTreatedAsDefault_Sync() + { + static IEnumerable IterateUntilSet(StrongBox box) + { + int counter = 0; + while (!box.Value) + { + yield return counter++; + } + } + + var box = new StrongBox(false); + + int activeWorkers = 0; + var block = new TaskCompletionSource(); + + Task t = Parallel.ForEachAsync(IterateUntilSet(box), new ParallelOptions { TaskScheduler = new MaxConcurrencyLevelPassthroughTaskScheduler(-42) }, async (item, cancellationToken) => + { + Interlocked.Increment(ref activeWorkers); + await block.Task; + }); + Assert.False(t.IsCompleted); + + await Task.Delay(20); // give the loop some time to run + + box.Value = true; + block.SetResult(); + await t; + + Assert.InRange(activeWorkers, 0, Environment.ProcessorCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Dop_NegativeTaskSchedulerLimitTreatedAsDefault_Async() + { + static async IAsyncEnumerable IterateUntilSet(StrongBox box) + { + int counter = 0; + while (!box.Value) + { + await Task.Yield(); + yield return counter++; + } + } + + var box = new StrongBox(false); + + int activeWorkers = 0; + var block = new TaskCompletionSource(); + + Task t = Parallel.ForEachAsync(IterateUntilSet(box), new ParallelOptions { TaskScheduler = new MaxConcurrencyLevelPassthroughTaskScheduler(-42) }, async (item, cancellationToken) => + { + Interlocked.Increment(ref activeWorkers); + await block.Task; + }); + Assert.False(t.IsCompleted); + + await Task.Delay(20); // give the loop some time to run + + box.Value = true; + block.SetResult(); + await t; + + Assert.InRange(activeWorkers, 0, Environment.ProcessorCount); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task RunsAsynchronously_EvenForEntirelySynchronousWork_Sync() + { + static IEnumerable Iterate() + { + while (true) yield return 0; + } + + var cts = new CancellationTokenSource(); + + Task t = Parallel.ForEachAsync(Iterate(), cts.Token, (item, cancellationToken) => default); + Assert.False(t.IsCompleted); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(() => t); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task RunsAsynchronously_EvenForEntirelySynchronousWork_Async() + { +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + static async IAsyncEnumerable IterateAsync() +#pragma warning restore CS1998 + { + while (true) yield return 0; + } + + var cts = new CancellationTokenSource(); + + Task t = Parallel.ForEachAsync(IterateAsync(), cts.Token, (item, cancellationToken) => default); + Assert.False(t.IsCompleted); + + cts.Cancel(); + + await Assert.ThrowsAnyAsync(() => t); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(-1)] + [InlineData(1)] + [InlineData(2)] + [InlineData(4)] + [InlineData(128)] + public async Task Dop_WorkersCreatedRespectingLimit_Async(int dop) + { + static async IAsyncEnumerable IterateUntilSetAsync(StrongBox box) + { + int counter = 0; + while (!box.Value) + { + await Task.Yield(); + yield return counter++; + } + } + + var box = new StrongBox(false); + + int activeWorkers = 0; + var block = new TaskCompletionSource(); + + Task t = Parallel.ForEachAsync(IterateUntilSetAsync(box), new ParallelOptions { MaxDegreeOfParallelism = dop }, async (item, cancellationToken) => + { + Interlocked.Increment(ref activeWorkers); + await block.Task; + }); + Assert.False(t.IsCompleted); + + await Task.Delay(20); // give the loop some time to run + + box.Value = true; + block.SetResult(); + await t; + + Assert.InRange(activeWorkers, 0, dop == -1 ? Environment.ProcessorCount : dop); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task EmptySource_Sync() + { + int counter = 0; + await Parallel.ForEachAsync(Enumerable.Range(0, 0), (item, cancellationToken) => + { + Interlocked.Increment(ref counter); + return default; + }); + + Assert.Equal(0, counter); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task EmptySource_Async() + { + int counter = 0; + await Parallel.ForEachAsync(EnumerableRangeAsync(0, 0), (item, cancellationToken) => + { + Interlocked.Increment(ref counter); + return default; + }); + + Assert.Equal(0, counter); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task AllItemsEnumeratedOnce_Sync(bool yield) + { + const int Start = 10, Count = 100; + + var set = new HashSet(); + + await Parallel.ForEachAsync(Enumerable.Range(Start, Count), async (item, cancellationToken) => + { + lock (set) + { + Assert.True(set.Add(item)); + } + + if (yield) + { + await Task.Yield(); + } + }); + + for (int i = Start; i < Start + Count; i++) + { + Assert.True(set.Contains(i)); + } + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task AllItemsEnumeratedOnce_Async(bool yield) + { + const int Start = 10, Count = 100; + + var set = new HashSet(); + + await Parallel.ForEachAsync(EnumerableRangeAsync(Start, Count, yield), async (item, cancellationToken) => + { + lock (set) + { + Assert.True(set.Add(item)); + } + + if (yield) + { + await Task.Yield(); + } + }); + + for (int i = Start; i < Start + Count; i++) + { + Assert.True(set.Contains(i)); + } + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task TaskScheduler_AllCodeExecutedOnCorrectScheduler_Sync(bool defaultScheduler) + { + TaskScheduler scheduler = defaultScheduler ? + TaskScheduler.Default : + new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + TaskScheduler otherScheduler = new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + IEnumerable Iterate() + { + Assert.Same(scheduler, TaskScheduler.Current); + for (int i = 1; i <= 100; i++) + { + yield return i; + Assert.Same(scheduler, TaskScheduler.Current); + } + } + + var cq = new ConcurrentQueue(); + + await Parallel.ForEachAsync(Iterate(), new ParallelOptions { TaskScheduler = scheduler }, async (item, cancellationToken) => + { + Assert.Same(scheduler, TaskScheduler.Current); + await Task.Yield(); + cq.Enqueue(item); + + if (item % 10 == 0) + { + await new SwitchTo(otherScheduler); + } + }); + + Assert.Equal(Enumerable.Range(1, 100), cq.OrderBy(i => i)); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task TaskScheduler_AllCodeExecutedOnCorrectScheduler_Async(bool defaultScheduler) + { + TaskScheduler scheduler = defaultScheduler ? + TaskScheduler.Default : + new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + TaskScheduler otherScheduler = new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + async IAsyncEnumerable Iterate() + { + Assert.Same(scheduler, TaskScheduler.Current); + for (int i = 1; i <= 100; i++) + { + await Task.Yield(); + yield return i; + Assert.Same(scheduler, TaskScheduler.Current); + } + } + + var cq = new ConcurrentQueue(); + + await Parallel.ForEachAsync(Iterate(), new ParallelOptions { TaskScheduler = scheduler }, async (item, cancellationToken) => + { + Assert.Same(scheduler, TaskScheduler.Current); + await Task.Yield(); + cq.Enqueue(item); + + if (item % 10 == 0) + { + await new SwitchTo(otherScheduler); + } + }); + + Assert.Equal(Enumerable.Range(1, 100), cq.OrderBy(i => i)); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_CancelsIterationAndReturnsCanceledTask_Sync() + { + static async IAsyncEnumerable Infinite() + { + int i = 0; + while (true) + { + await Task.Yield(); + yield return i++; + } + } + + using var cts = new CancellationTokenSource(10); + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Parallel.ForEachAsync(Infinite(), cts.Token, async (item, cancellationToken) => + { + await Task.Yield(); + })); + Assert.Equal(cts.Token, oce.CancellationToken); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_CancelsIterationAndReturnsCanceledTask_Async() + { + static async IAsyncEnumerable InfiniteAsync() + { + int i = 0; + while (true) + { + await Task.Yield(); + yield return i++; + } + } + + using var cts = new CancellationTokenSource(10); + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Parallel.ForEachAsync(InfiniteAsync(), cts.Token, async (item, cancellationToken) => + { + await Task.Yield(); + })); + Assert.Equal(cts.Token, oce.CancellationToken); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_CorrectTokenPassedToAsyncEnumerator() + { + static async IAsyncEnumerable YieldTokenAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return cancellationToken; + } + + await Parallel.ForEachAsync(YieldTokenAsync(default), (item, cancellationToken) => + { + Assert.Equal(cancellationToken, item); + return default; + }); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_SameTokenPassedToEveryInvocation_Sync() + { + var cq = new ConcurrentQueue(); + + await Parallel.ForEachAsync(Enumerable.Range(1, 100), async (item, cancellationToken) => + { + cq.Enqueue(cancellationToken); + await Task.Yield(); + }); + + Assert.Equal(100, cq.Count); + Assert.Equal(1, cq.Distinct().Count()); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_SameTokenPassedToEveryInvocation_Async() + { + var cq = new ConcurrentQueue(); + + await Parallel.ForEachAsync(EnumerableRangeAsync(1, 100), async (item, cancellationToken) => + { + cq.Enqueue(cancellationToken); + await Task.Yield(); + }); + + Assert.Equal(100, cq.Count); + Assert.Equal(1, cq.Distinct().Count()); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_HasPriorityOverExceptions_Sync() + { + static IEnumerable Iterate() + { + int counter = 0; + while (true) yield return counter++; + } + + var tcs = new TaskCompletionSource(); + var cts = new CancellationTokenSource(); + + Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = 2 }, async (item, cancellationToken) => + { + if (item == 0) + { + await tcs.Task; + cts.Cancel(); + throw new FormatException(); + } + else + { + tcs.TrySetResult(); + await Task.Yield(); + } + }); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => t); + Assert.Equal(cts.Token, oce.CancellationToken); + Assert.True(t.IsCanceled); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Cancellation_HasPriorityOverExceptions_Async() + { + static async IAsyncEnumerable Iterate() + { + int counter = 0; + while (true) + { + await Task.Yield(); + yield return counter++; + } + } + + var tcs = new TaskCompletionSource(); + var cts = new CancellationTokenSource(); + + Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = 2 }, async (item, cancellationToken) => + { + if (item == 0) + { + await tcs.Task; + cts.Cancel(); + throw new FormatException(); + } + else + { + tcs.TrySetResult(); + await Task.Yield(); + } + }); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => t); + Assert.Equal(cts.Token, oce.CancellationToken); + Assert.True(t.IsCanceled); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void Exception_FromGetEnumerator_Sync() + { + Task t = Parallel.ForEachAsync((IEnumerable)new ThrowsFromGetEnumerator(), (item, cancellationToken) => default); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + Assert.IsType(t.Exception.InnerException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void Exception_FromGetEnumerator_Async() + { + Task t = Parallel.ForEachAsync((IAsyncEnumerable)new ThrowsFromGetEnumerator(), (item, cancellationToken) => default); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + Assert.IsType(t.Exception.InnerException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void Exception_NullFromGetEnumerator_Sync() + { + Task t = Parallel.ForEachAsync((IEnumerable)new ReturnsNullFromGetEnumerator(), (item, cancellationToken) => default); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + Assert.IsType(t.Exception.InnerException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public void Exception_NullFromGetEnumerator_Async() + { + Task t = Parallel.ForEachAsync((IAsyncEnumerable)new ReturnsNullFromGetEnumerator(), (item, cancellationToken) => default); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + Assert.IsType(t.Exception.InnerException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromMoveNext_Sync() + { + static IEnumerable Iterate() + { + for (int i = 0; i < 10; i++) + { + if (i == 4) + { + throw new FormatException(); + } + yield return i; + } + } + + Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default); + await Assert.ThrowsAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromMoveNext_Async() + { + static async IAsyncEnumerable Iterate() + { + for (int i = 0; i < 10; i++) + { + await Task.Yield(); + if (i == 4) + { + throw new FormatException(); + } + yield return i; + } + } + + Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default); + await Assert.ThrowsAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromLoopBody_Sync() + { + static IEnumerable Iterate() + { + yield return 1; + yield return 2; + } + + var barrier = new Barrier(2); + Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { MaxDegreeOfParallelism = barrier.ParticipantCount }, (item, cancellationToken) => + { + barrier.SignalAndWait(); + throw item switch + { + 1 => new FormatException(), + 2 => new InvalidTimeZoneException(), + _ => new Exception() + }; + }); + await Assert.ThrowsAnyAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(2, t.Exception.InnerExceptions.Count); + Assert.Contains(t.Exception.InnerExceptions, e => e is FormatException); + Assert.Contains(t.Exception.InnerExceptions, e => e is InvalidTimeZoneException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromLoopBody_Async() + { + static async IAsyncEnumerable Iterate() + { + await Task.Yield(); + yield return 1; + yield return 2; + yield return 3; + yield return 4; + } + + int remaining = 4; + var tcs = new TaskCompletionSource(); + + Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { MaxDegreeOfParallelism = 4 }, async (item, cancellationToken) => + { + if (Interlocked.Decrement(ref remaining) == 0) + { + tcs.SetResult(); + } + await tcs.Task; + + throw item switch + { + 1 => new FormatException(), + 2 => new InvalidTimeZoneException(), + 3 => new ArithmeticException(), + 4 => new DivideByZeroException(), + _ => new Exception() + }; + }); + await Assert.ThrowsAnyAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(4, t.Exception.InnerExceptions.Count); + Assert.Contains(t.Exception.InnerExceptions, e => e is FormatException); + Assert.Contains(t.Exception.InnerExceptions, e => e is InvalidTimeZoneException); + Assert.Contains(t.Exception.InnerExceptions, e => e is ArithmeticException); + Assert.Contains(t.Exception.InnerExceptions, e => e is DivideByZeroException); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromDispose_Sync() + { + Task t = Parallel.ForEachAsync((IEnumerable)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default); + await Assert.ThrowsAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_FromDispose_Async() + { + Task t = Parallel.ForEachAsync((IAsyncEnumerable)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default); + await Assert.ThrowsAsync(() => t); + Assert.True(t.IsFaulted); + Assert.Equal(1, t.Exception.InnerExceptions.Count); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_ImplicitlyCancelsOtherWorkers_Sync() + { + static IEnumerable Iterate() + { + int i = 0; + while (true) + { + yield return i++; + } + } + + await Assert.ThrowsAsync(() => Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) => + { + await Task.Yield(); + if (item == 1000) + { + throw new Exception(); + } + })); + + await Assert.ThrowsAsync(() => Parallel.ForEachAsync(Iterate(), new ParallelOptions { MaxDegreeOfParallelism = 2 }, async (item, cancellationToken) => + { + if (item == 0) + { + throw new FormatException(); + } + else + { + Assert.Equal(1, item); + var tcs = new TaskCompletionSource(); + cancellationToken.Register(() => tcs.SetResult()); + await tcs.Task; + } + })); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public async Task Exception_ImplicitlyCancelsOtherWorkers_Async() + { + static async IAsyncEnumerable Iterate() + { + int i = 0; + while (true) + { + await Task.Yield(); + yield return i++; + } + } + + await Assert.ThrowsAsync(() => Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) => + { + await Task.Yield(); + if (item == 1000) + { + throw new Exception(); + } + })); + + await Assert.ThrowsAsync(() => Parallel.ForEachAsync(Iterate(), new ParallelOptions { MaxDegreeOfParallelism = 2 }, async (item, cancellationToken) => + { + if (item == 0) + { + throw new FormatException(); + } + else + { + Assert.Equal(1, item); + var tcs = new TaskCompletionSource(); + cancellationToken.Register(() => tcs.SetResult()); + await tcs.Task; + } + })); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false)] + [InlineData(true)] + public async Task ExecutionContext_FlowsToWorkerBodies_Sync(bool defaultScheduler) + { + TaskScheduler scheduler = defaultScheduler ? + TaskScheduler.Default : + new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + static IEnumerable Iterate() + { + for (int i = 0; i < 100; i++) + { + yield return i; + } + } + + var al = new AsyncLocal(); + al.Value = 42; + await Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) => + { + await Task.Yield(); + Assert.Equal(42, al.Value); + }); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task ExecutionContext_FlowsToWorkerBodies_Async(bool defaultScheduler, bool flowContext) + { + TaskScheduler scheduler = defaultScheduler ? + TaskScheduler.Default : + new ConcurrentExclusiveSchedulerPair().ConcurrentScheduler; + + static async IAsyncEnumerable Iterate() + { + for (int i = 0; i < 100; i++) + { + await Task.Yield(); + yield return i; + } + } + + var al = new AsyncLocal(); + al.Value = 42; + + if (!flowContext) + { + ExecutionContext.SuppressFlow(); + } + + Task t = Parallel.ForEachAsync(Iterate(), async (item, cancellationToken) => + { + await Task.Yield(); + Assert.Equal(flowContext ? 42 : 0, al.Value); + }); + + if (!flowContext) + { + ExecutionContext.RestoreFlow(); + } + + await t; + } + + private static async IAsyncEnumerable EnumerableRangeAsync(int start, int count, bool yield = true) + { + for (int i = start; i < start + count; i++) + { + if (yield) + { + await Task.Yield(); + } + + yield return i; + } + } + + private sealed class ThrowsFromGetEnumerator : IAsyncEnumerable, IEnumerable + { + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => throw new DivideByZeroException(); + public IEnumerator GetEnumerator() => throw new FormatException(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + private sealed class ReturnsNullFromGetEnumerator : IAsyncEnumerable, IEnumerable + { + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => null; + public IEnumerator GetEnumerator() => null; + IEnumerator IEnumerable.GetEnumerator() => null; + } + + private sealed class ThrowsExceptionFromDispose : IAsyncEnumerable, IEnumerable, IAsyncEnumerator, IEnumerator + { + public int Current => throw new NotImplementedException(); + object IEnumerator.Current => throw new NotImplementedException(); + + public void Dispose() => throw new FormatException(); + public ValueTask DisposeAsync() => throw new DivideByZeroException(); + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => this; + public IEnumerator GetEnumerator() => this; + IEnumerator IEnumerable.GetEnumerator() => this; + + public bool MoveNext() => false; + public ValueTask MoveNextAsync() => new ValueTask(false); + + public void Reset() => throw new NotImplementedException(); + } + + private sealed class SwitchTo : INotifyCompletion + { + private readonly TaskScheduler _scheduler; + + public SwitchTo(TaskScheduler scheduler) => _scheduler = scheduler; + + public SwitchTo GetAwaiter() => this; + public bool IsCompleted => false; + public void GetResult() { } + public void OnCompleted(Action continuation) => Task.Factory.StartNew(continuation, CancellationToken.None, TaskCreationOptions.None, _scheduler); + } + + private sealed class MaxConcurrencyLevelPassthroughTaskScheduler : TaskScheduler + { + public MaxConcurrencyLevelPassthroughTaskScheduler(int maximumConcurrencyLevel) => + MaximumConcurrencyLevel = maximumConcurrencyLevel; + + protected override IEnumerable GetScheduledTasks() => Array.Empty(); + protected override void QueueTask(Task task) => ThreadPool.QueueUserWorkItem(_ => TryExecuteTask(task)); + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => TryExecuteTask(task); + + public override int MaximumConcurrencyLevel { get; } + } + } +} diff --git a/src/libraries/System.Threading.Tasks.Parallel/tests/System.Threading.Tasks.Parallel.Tests.csproj b/src/libraries/System.Threading.Tasks.Parallel/tests/System.Threading.Tasks.Parallel.Tests.csproj index b758ad3f2f258..c26f5f363a042 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/tests/System.Threading.Tasks.Parallel.Tests.csproj +++ b/src/libraries/System.Threading.Tasks.Parallel/tests/System.Threading.Tasks.Parallel.Tests.csproj @@ -7,13 +7,13 @@ - + + @@ -24,7 +24,6 @@ - +