diff --git a/src/bunit.core/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensions.cs b/src/bunit.core/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensions.cs index 912ee641c..ed6fedd6e 100644 --- a/src/bunit.core/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensions.cs +++ b/src/bunit.core/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensions.cs @@ -1,4 +1,5 @@ using System; +using System.Runtime.ExceptionServices; using Bunit.Extensions.WaitForHelpers; namespace Bunit @@ -22,13 +23,21 @@ public static class RenderedFragmentWaitForHelperExtensions public static void WaitForState(this IRenderedFragmentBase renderedFragment, Func statePredicate, TimeSpan? timeout = null) { using var waiter = new WaitForStateHelper(renderedFragment, statePredicate, timeout); + try { waiter.WaitTask.GetAwaiter().GetResult(); } - catch (AggregateException e) when (e.InnerException is not null) + catch (Exception e) { - throw e.InnerException; + if (e is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1) + { + ExceptionDispatchInfo.Capture(aggregateException.InnerExceptions[0]).Throw(); + } + else + { + ExceptionDispatchInfo.Capture(e).Throw(); + } } } @@ -45,13 +54,21 @@ public static void WaitForState(this IRenderedFragmentBase renderedFragment, Fun public static void WaitForAssertion(this IRenderedFragmentBase renderedFragment, Action assertion, TimeSpan? timeout = null) { using var waiter = new WaitForAssertionHelper(renderedFragment, assertion, timeout); + try { waiter.WaitTask.GetAwaiter().GetResult(); } - catch (AggregateException e) when (e.InnerException is not null) + catch (Exception e) { - throw e.InnerException; + if (e is AggregateException aggregateException && aggregateException.InnerExceptions.Count == 1) + { + ExceptionDispatchInfo.Capture(aggregateException.InnerExceptions[0]).Throw(); + } + else + { + ExceptionDispatchInfo.Capture(e).Throw(); + } } } } diff --git a/src/bunit.core/Extensions/WaitForHelpers/WaitForHelper.cs b/src/bunit.core/Extensions/WaitForHelpers/WaitForHelper.cs index b7c661f37..01e020b89 100644 --- a/src/bunit.core/Extensions/WaitForHelpers/WaitForHelper.cs +++ b/src/bunit.core/Extensions/WaitForHelpers/WaitForHelper.cs @@ -1,6 +1,8 @@ using System; using System.Threading; using System.Threading.Tasks; +using Bunit.Rendering; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Bunit.Extensions.WaitForHelpers @@ -13,7 +15,7 @@ public abstract class WaitForHelper : IDisposable { private readonly object lockObject = new(); private readonly Timer timer; - private readonly TaskCompletionSource completionSouce; + private readonly TaskCompletionSource checkPassedCompletionSouce; private readonly Func completeChecker; private readonly IRenderedFragmentBase renderedFragment; private readonly ILogger logger; @@ -40,7 +42,7 @@ public abstract class WaitForHelper : IDisposable /// Gets the task that will complete successfully if the check passed before the timeout was reached. /// The task will complete with an exception if the timeout was reached without the check passing. /// - public Task WaitTask => completionSouce.Task; + public Task WaitTask { get; } /// /// Initializes a new instance of the class. @@ -50,13 +52,25 @@ protected WaitForHelper(IRenderedFragmentBase renderedFragment, Func compl this.renderedFragment = renderedFragment ?? throw new ArgumentNullException(nameof(renderedFragment)); this.completeChecker = completeChecker ?? throw new ArgumentNullException(nameof(completeChecker)); logger = renderedFragment.Services.CreateLogger(); - completionSouce = new TaskCompletionSource(); + + var renderer = renderedFragment.Services.GetRequiredService(); + var renderException = renderer + .UnhandledException + .ContinueWith(x => Task.FromException(x.Result), CancellationToken.None, TaskContinuationOptions.OnlyOnRanToCompletion, TaskScheduler.Current) + .Unwrap(); + + checkPassedCompletionSouce = new TaskCompletionSource(); + WaitTask = Task.WhenAny(checkPassedCompletionSouce.Task, renderException).Unwrap(); + timer = new Timer(OnTimeout, this, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan); - OnAfterRender(this, EventArgs.Empty); - this.renderedFragment.OnAfterRender += OnAfterRender; - OnAfterRender(this, EventArgs.Empty); - StartTimer(timeout); + if (!WaitTask.IsCompleted) + { + OnAfterRender(this, EventArgs.Empty); + this.renderedFragment.OnAfterRender += OnAfterRender; + OnAfterRender(this, EventArgs.Empty); + StartTimer(timeout); + } } private void StartTimer(TimeSpan? timeout) @@ -88,7 +102,7 @@ private void OnAfterRender(object? sender, EventArgs args) logger.LogDebug(new EventId(1, nameof(OnAfterRender)), $"Checking the wait condition for component {renderedFragment.ComponentId}"); if (completeChecker()) { - completionSouce.TrySetResult(null); + checkPassedCompletionSouce.TrySetResult(null); logger.LogDebug(new EventId(2, nameof(OnAfterRender)), $"The check completed successfully for component {renderedFragment.ComponentId}"); Dispose(); } @@ -104,7 +118,7 @@ private void OnAfterRender(object? sender, EventArgs args) if (StopWaitingOnCheckException) { - completionSouce.TrySetException(new WaitForFailedException(CheckThrowErrorMessage, capturedException)); + checkPassedCompletionSouce.TrySetException(new WaitForFailedException(CheckThrowErrorMessage, capturedException)); Dispose(); } } @@ -123,7 +137,7 @@ private void OnTimeout(object? state) logger.LogDebug(new EventId(5, nameof(OnTimeout)), $"The wait for helper for component {renderedFragment.ComponentId} timed out"); - completionSouce.TrySetException(new WaitForFailedException(TimeoutErrorMessage, capturedException)); + checkPassedCompletionSouce.TrySetException(new WaitForFailedException(TimeoutErrorMessage, capturedException)); Dispose(); } @@ -160,7 +174,7 @@ protected virtual void Dispose(bool disposing) isDisposed = true; renderedFragment.OnAfterRender -= OnAfterRender; timer.Dispose(); - completionSouce.TrySetCanceled(); + checkPassedCompletionSouce.TrySetCanceled(); logger.LogDebug(new EventId(6, nameof(Dispose)), $"The state wait helper for component {renderedFragment.ComponentId} disposed"); } } diff --git a/tests/bunit.core.tests/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensionsTest.cs b/tests/bunit.core.tests/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensionsTest.cs index af5ff60c9..8045a0d39 100644 --- a/tests/bunit.core.tests/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensionsTest.cs +++ b/tests/bunit.core.tests/Extensions/WaitForHelpers/RenderedFragmentWaitForHelperExtensionsTest.cs @@ -1,117 +1,148 @@ using System; +using System.Threading.Tasks; using AngleSharp.Dom; using Bunit.TestAssets.SampleComponents; +using Microsoft.AspNetCore.Components; using Shouldly; using Xunit; using Xunit.Abstractions; namespace Bunit.Extensions.WaitForHelpers { - public class RenderedFragmentWaitForHelperExtensionsTest : TestContext - { - public RenderedFragmentWaitForHelperExtensionsTest(ITestOutputHelper testOutput) - { - Services.AddXunitLogger(testOutput); - } - - [Fact(DisplayName = "WaitForAssertion can wait for multiple renders and changes to occur")] - public void Test110() - { - // Initial state is stopped - var cut = RenderComponent(); - var stateElement = cut.Find("#state"); - stateElement.TextContent.ShouldBe("Stopped"); - - // Clicking 'tick' changes the state, and starts a task - cut.Find("#tick").Click(); - cut.Find("#state").TextContent.ShouldBe("Started"); - - // Clicking 'tock' completes the task, which updates the state - // This click causes two renders, thus something is needed to await here. - cut.Find("#tock").Click(); - cut.WaitForAssertion(() => cut.Find("#state").TextContent.ShouldBe("Stopped")); - } - - [Fact(DisplayName = "WaitForAssertion throws assertion exception after timeout")] - public void Test011() - { - var cut = RenderComponent(); - - var expected = Should.Throw(() => - cut.WaitForAssertion(() => cut.Markup.ShouldBeEmpty(), TimeSpan.FromMilliseconds(10))); - - expected.Message.ShouldBe(WaitForAssertionHelper.TimeoutMessage); - expected.InnerException.ShouldBeOfType(); - } - - [Fact(DisplayName = "WaitForState throws exception after timeout")] - public void Test012() - { - var cut = RenderComponent(); - - var expected = Should.Throw(() => - cut.WaitForState(() => string.IsNullOrEmpty(cut.Markup), TimeSpan.FromMilliseconds(100))); - - expected.Message.ShouldBe(WaitForStateHelper.TimeoutBeforePassMessage); - } - - [Fact(DisplayName = "WaitForState throws exception if statePredicate throws on a later render")] - public void Test013() - { - const string expectedInnerMessage = "INNER MESSAGE"; - var cut = RenderComponent(); - cut.Find("#tick").Click(); - cut.Find("#tock").Click(); - - var expected = Should.Throw(() => - cut.WaitForState(() => - { - if (cut.Find("#state").TextContent == "Stopped") - throw new InvalidOperationException(expectedInnerMessage); - return false; - })); - - expected.Message.ShouldBe(WaitForStateHelper.ExceptionInPredicateMessage); - expected.InnerException.ShouldBeOfType() - .Message.ShouldBe(expectedInnerMessage); - } - - [Fact(DisplayName = "WaitForState can wait for multiple renders and changes to occur")] - public void Test100() - { - // Initial state is stopped - var cut = RenderComponent(); - - // Clicking 'tick' changes the state, and starts a task - cut.Find("#tick").Click(); - cut.Find("#state").TextContent.ShouldBe("Started"); - - // Clicking 'tock' completes the task, which updates the state - // This click causes two renders, thus something is needed to await here. - cut.Find("#tock").Click(); - cut.WaitForState(() => - { - var elm = cut.Nodes.QuerySelector("#state"); - return elm?.TextContent == "Stopped"; - }); - } - - [Fact(DisplayName = "WaitForState can detect async changes to properties in the CUT")] - public void Test200() - { - var cut = RenderComponent(); - cut.Instance.Counter.ShouldBe(0); - - // Clicking 'tick' changes the counter, and starts a task - cut.Find("#tick").Click(); - cut.Instance.Counter.ShouldBe(1); - - // Clicking 'tock' completes the task, which updates the counter - // This click causes two renders, thus something is needed to await here. - cut.Find("#tock").Click(); - cut.WaitForState(() => cut.Instance.Counter == 2); - - cut.Instance.Counter.ShouldBe(2); - } - } + public class RenderedFragmentWaitForHelperExtensionsTest : TestContext + { + public RenderedFragmentWaitForHelperExtensionsTest(ITestOutputHelper testOutput) + { + Services.AddXunitLogger(testOutput); + } + + [Fact(DisplayName = "WaitForAssertion can wait for multiple renders and changes to occur")] + public void Test110() + { + // Initial state is stopped + var cut = RenderComponent(); + var stateElement = cut.Find("#state"); + stateElement.TextContent.ShouldBe("Stopped"); + + // Clicking 'tick' changes the state, and starts a task + cut.Find("#tick").Click(); + cut.Find("#state").TextContent.ShouldBe("Started"); + + // Clicking 'tock' completes the task, which updates the state + // This click causes two renders, thus something is needed to await here. + cut.Find("#tock").Click(); + cut.WaitForAssertion(() => cut.Find("#state").TextContent.ShouldBe("Stopped")); + } + + [Fact(DisplayName = "WaitForAssertion throws assertion exception after timeout")] + public void Test011() + { + var cut = RenderComponent(); + + var expected = Should.Throw(() => + cut.WaitForAssertion(() => cut.Markup.ShouldBeEmpty(), TimeSpan.FromMilliseconds(10))); + + expected.Message.ShouldBe(WaitForAssertionHelper.TimeoutMessage); + expected.InnerException.ShouldBeOfType(); + } + + [Fact(DisplayName = "WaitForState throws exception after timeout")] + public void Test012() + { + var cut = RenderComponent(); + + var expected = Should.Throw(() => + cut.WaitForState(() => string.IsNullOrEmpty(cut.Markup), TimeSpan.FromMilliseconds(100))); + + expected.Message.ShouldBe(WaitForStateHelper.TimeoutBeforePassMessage); + } + + [Fact(DisplayName = "WaitForState throws exception if statePredicate throws on a later render")] + public void Test013() + { + const string expectedInnerMessage = "INNER MESSAGE"; + var cut = RenderComponent(); + cut.Find("#tick").Click(); + cut.Find("#tock").Click(); + + var expected = Should.Throw(() => + cut.WaitForState(() => + { + if (cut.Find("#state").TextContent == "Stopped") + throw new InvalidOperationException(expectedInnerMessage); + return false; + })); + + expected.Message.ShouldBe(WaitForStateHelper.ExceptionInPredicateMessage); + expected.InnerException.ShouldBeOfType() + .Message.ShouldBe(expectedInnerMessage); + } + + [Fact(DisplayName = "WaitForState can wait for multiple renders and changes to occur")] + public void Test100() + { + // Initial state is stopped + var cut = RenderComponent(); + + // Clicking 'tick' changes the state, and starts a task + cut.Find("#tick").Click(); + cut.Find("#state").TextContent.ShouldBe("Started"); + + // Clicking 'tock' completes the task, which updates the state + // This click causes two renders, thus something is needed to await here. + cut.Find("#tock").Click(); + cut.WaitForState(() => + { + var elm = cut.Nodes.QuerySelector("#state"); + return elm?.TextContent == "Stopped"; + }); + } + + [Fact(DisplayName = "WaitForState can detect async changes to properties in the CUT")] + public void Test200() + { + var cut = RenderComponent(); + cut.Instance.Counter.ShouldBe(0); + + // Clicking 'tick' changes the counter, and starts a task + cut.Find("#tick").Click(); + cut.Instance.Counter.ShouldBe(1); + + // Clicking 'tock' completes the task, which updates the counter + // This click causes two renders, thus something is needed to await here. + cut.Find("#tock").Click(); + cut.WaitForState(() => cut.Instance.Counter == 2); + + cut.Instance.Counter.ShouldBe(2); + } + + [Fact(DisplayName = "WaitForAssertion rethrows unhandled exception from a components async operation's methods")] + public void Test300() + { + var cut = RenderComponent(); + + Should.Throw( + () => cut.WaitForAssertion(() => false.ShouldBeTrue())); + } + + [Fact(DisplayName = "WaitForState rethrows unhandled exception from components async operation's methods")] + public void Test301() + { + var cut = RenderComponent(); + + Should.Throw( + () => cut.WaitForState(() => false)); + } + + internal class ThrowsAfterAsyncOperation : ComponentBase + { + protected override async Task OnInitializedAsync() + { + await Task.Delay(1); + throw new ThrowsAfterAsyncOperationException(); + } + + internal sealed class ThrowsAfterAsyncOperationException : Exception { } + } + } }