diff --git a/src/libraries/Common/src/System/TimeProvider.cs b/src/libraries/Common/src/System/TimeProvider.cs index 7d056754dc655..d1489826eecc4 100644 --- a/src/libraries/Common/src/System/TimeProvider.cs +++ b/src/libraries/Common/src/System/TimeProvider.cs @@ -73,7 +73,7 @@ public DateTimeOffset GetLocalNow() /// /// The default implementation returns . /// - public virtual TimeZoneInfo LocalTimeZone { get => TimeZoneInfo.Local; } + public virtual TimeZoneInfo LocalTimeZone => TimeZoneInfo.Local; /// /// Gets the frequency of of high-frequency value per second. @@ -81,7 +81,7 @@ public DateTimeOffset GetLocalNow() /// /// The default implementation returns . For a given TimeProvider instance, the value must be idempotent and remain unchanged. /// - public virtual long TimestampFrequency { get => Stopwatch.Frequency; } + public virtual long TimestampFrequency => Stopwatch.Frequency; /// /// Gets the current high-frequency value designed to measure small time intervals with high accuracy in the timer mechanism. @@ -187,14 +187,27 @@ public SystemTimeProviderTimer(TimeSpan dueTime, TimeSpan period, TimerCallback #if SYSTEM_PRIVATE_CORELIB _timer = new TimerQueueTimer(callback, state, duration, periodTime, flowExecutionContext: true); #else - // We want to ensure the timer we create will be tracked as long as it is scheduled. - // To do that, we call the constructor which track only the callback which will make the time to be tracked by the scheduler - // then we call Change on the timer to set the desired duration and period. - _timer = new Timer(_ => callback(state)); - _timer.Change(duration, periodTime); + // We need to ensure the timer roots itself. Timer created with a duration and period argument + // only roots the state object, so to root the timer we need the state object to reference the + // timer recursively. + var timerState = new TimerState(callback, state); + timerState.Timer = _timer = new Timer(static s => + { + TimerState ts = (TimerState)s!; + ts.Callback(ts.State); + }, timerState, duration, periodTime); #endif // SYSTEM_PRIVATE_CORELIB } +#if !SYSTEM_PRIVATE_CORELIB + private sealed class TimerState(TimerCallback callback, object? state) + { + public TimerCallback Callback { get; } = callback; + public object? State { get; } = state; + public Timer? Timer { get; set; } + } +#endif + public bool Change(TimeSpan dueTime, TimeSpan period) { (uint duration, uint periodTime) = CheckAndGetValues(dueTime, period); diff --git a/src/libraries/Common/tests/System/TimeProviderTests.cs b/src/libraries/Common/tests/System/TimeProviderTests.cs index 21d477e26bd4d..b3fe700165e2f 100644 --- a/src/libraries/Common/tests/System/TimeProviderTests.cs +++ b/src/libraries/Common/tests/System/TimeProviderTests.cs @@ -409,6 +409,36 @@ public static void NegativeTests() #endif // !NETFRAMEWORK } +#if TESTEXTENSIONS + [Fact] + public static void InvokeCallbackFromCreateTimer() + { + TimeProvider p = new InvokeCallbackCreateTimerProvider(); + + CancellationTokenSource cts = p.CreateCancellationTokenSource(TimeSpan.FromSeconds(0)); + Assert.True(cts.IsCancellationRequested); + + Task t = p.Delay(TimeSpan.FromSeconds(0)); + Assert.True(t.IsCompleted); + + t = new TaskCompletionSource().Task.WaitAsync(TimeSpan.FromSeconds(0), p); + Assert.True(t.IsFaulted); + } + + class InvokeCallbackCreateTimerProvider : TimeProvider + { + public override ITimer CreateTimer(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period) + { + ITimer t = base.CreateTimer(callback, state, dueTime, period); + if (dueTime != Timeout.InfiniteTimeSpan) + { + callback(state); + } + return t; + } + } +#endif + class TimerState { public TimerState() diff --git a/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs b/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs index c8354d7f7c767..9f8ba325f1d14 100644 --- a/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs +++ b/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs @@ -15,15 +15,25 @@ public static class TimeProviderTaskExtensions #if !NET8_0_OR_GREATER private sealed class DelayState : TaskCompletionSource { - public DelayState() : base(TaskCreationOptions.RunContinuationsAsynchronously) {} - public ITimer Timer { get; set; } + public DelayState(CancellationToken cancellationToken) : base(TaskCreationOptions.RunContinuationsAsynchronously) + { + CancellationToken = cancellationToken; + } + + public ITimer? Timer { get; set; } + public CancellationToken CancellationToken { get; } public CancellationTokenRegistration Registration { get; set; } } private sealed class WaitAsyncState : TaskCompletionSource { - public WaitAsyncState() : base(TaskCreationOptions.RunContinuationsAsynchronously) { } + public WaitAsyncState(CancellationToken cancellationToken) : base(TaskCreationOptions.RunContinuationsAsynchronously) + { + CancellationToken = cancellationToken; + } + public readonly CancellationTokenSource ContinuationCancellation = new CancellationTokenSource(); + public CancellationToken CancellationToken { get; } public CancellationTokenRegistration Registration; public ITimer? Timer; } @@ -66,22 +76,22 @@ public static Task Delay(this TimeProvider timeProvider, TimeSpan delay, Cancell return Task.FromCanceled(cancellationToken); } - DelayState state = new(); + DelayState state = new(cancellationToken); - state.Timer = timeProvider.CreateTimer(delayState => + state.Timer = timeProvider.CreateTimer(static delayState => { DelayState s = (DelayState)delayState!; s.TrySetResult(true); s.Registration.Dispose(); - s?.Timer.Dispose(); + s.Timer?.Dispose(); }, state, delay, Timeout.InfiniteTimeSpan); - state.Registration = cancellationToken.Register(delayState => + state.Registration = cancellationToken.Register(static delayState => { DelayState s = (DelayState)delayState!; - s.TrySetCanceled(cancellationToken); + s.TrySetCanceled(s.CancellationToken); s.Registration.Dispose(); - s?.Timer.Dispose(); + s.Timer?.Dispose(); }, state); // There are race conditions where the timer fires after we have attached the cancellation callback but before the @@ -153,7 +163,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time return Task.FromCanceled(cancellationToken); } - var state = new WaitAsyncState(); + WaitAsyncState state = new(cancellationToken); state.Timer = timeProvider.CreateTimer(static s => { @@ -162,7 +172,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time state.TrySetException(new TimeoutException()); state.Registration.Dispose(); - state.Timer!.Dispose(); + state.Timer?.Dispose(); state.ContinuationCancellation.Cancel(); }, state, timeout, Timeout.InfiniteTimeSpan); @@ -182,7 +192,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time { var state = (WaitAsyncState)s!; - state.TrySetCanceled(); + state.TrySetCanceled(state.CancellationToken); state.Timer?.Dispose(); state.ContinuationCancellation.Cancel(); @@ -259,16 +269,16 @@ public static CancellationTokenSource CreateCancellationTokenSource(this TimePro var cts = new CancellationTokenSource(); - ITimer timer = timeProvider.CreateTimer(s => + ITimer timer = timeProvider.CreateTimer(static s => { try { - ((CancellationTokenSource)s).Cancel(); + ((CancellationTokenSource)s!).Cancel(); } catch (ObjectDisposedException) { } }, cts, delay, Timeout.InfiniteTimeSpan); - cts.Token.Register(t => ((ITimer)t).Dispose(), timer); + cts.Token.Register(static t => ((ITimer)t!).Dispose(), timer); return cts; #endif // NET8_0_OR_GREATER }