diff --git a/src/libraries/Common/src/System/TimeProvider.cs b/src/libraries/Common/src/System/TimeProvider.cs
index 7d056754dc655..d1489826eecc4 100644
--- a/src/libraries/Common/src/System/TimeProvider.cs
+++ b/src/libraries/Common/src/System/TimeProvider.cs
@@ -73,7 +73,7 @@ public DateTimeOffset GetLocalNow()
///
/// The default implementation returns .
///
- public virtual TimeZoneInfo LocalTimeZone { get => TimeZoneInfo.Local; }
+ public virtual TimeZoneInfo LocalTimeZone => TimeZoneInfo.Local;
///
/// Gets the frequency of of high-frequency value per second.
@@ -81,7 +81,7 @@ public DateTimeOffset GetLocalNow()
///
/// The default implementation returns . For a given TimeProvider instance, the value must be idempotent and remain unchanged.
///
- public virtual long TimestampFrequency { get => Stopwatch.Frequency; }
+ public virtual long TimestampFrequency => Stopwatch.Frequency;
///
/// Gets the current high-frequency value designed to measure small time intervals with high accuracy in the timer mechanism.
@@ -187,14 +187,27 @@ public SystemTimeProviderTimer(TimeSpan dueTime, TimeSpan period, TimerCallback
#if SYSTEM_PRIVATE_CORELIB
_timer = new TimerQueueTimer(callback, state, duration, periodTime, flowExecutionContext: true);
#else
- // We want to ensure the timer we create will be tracked as long as it is scheduled.
- // To do that, we call the constructor which track only the callback which will make the time to be tracked by the scheduler
- // then we call Change on the timer to set the desired duration and period.
- _timer = new Timer(_ => callback(state));
- _timer.Change(duration, periodTime);
+ // We need to ensure the timer roots itself. Timer created with a duration and period argument
+ // only roots the state object, so to root the timer we need the state object to reference the
+ // timer recursively.
+ var timerState = new TimerState(callback, state);
+ timerState.Timer = _timer = new Timer(static s =>
+ {
+ TimerState ts = (TimerState)s!;
+ ts.Callback(ts.State);
+ }, timerState, duration, periodTime);
#endif // SYSTEM_PRIVATE_CORELIB
}
+#if !SYSTEM_PRIVATE_CORELIB
+ private sealed class TimerState(TimerCallback callback, object? state)
+ {
+ public TimerCallback Callback { get; } = callback;
+ public object? State { get; } = state;
+ public Timer? Timer { get; set; }
+ }
+#endif
+
public bool Change(TimeSpan dueTime, TimeSpan period)
{
(uint duration, uint periodTime) = CheckAndGetValues(dueTime, period);
diff --git a/src/libraries/Common/tests/System/TimeProviderTests.cs b/src/libraries/Common/tests/System/TimeProviderTests.cs
index 21d477e26bd4d..b3fe700165e2f 100644
--- a/src/libraries/Common/tests/System/TimeProviderTests.cs
+++ b/src/libraries/Common/tests/System/TimeProviderTests.cs
@@ -409,6 +409,36 @@ public static void NegativeTests()
#endif // !NETFRAMEWORK
}
+#if TESTEXTENSIONS
+ [Fact]
+ public static void InvokeCallbackFromCreateTimer()
+ {
+ TimeProvider p = new InvokeCallbackCreateTimerProvider();
+
+ CancellationTokenSource cts = p.CreateCancellationTokenSource(TimeSpan.FromSeconds(0));
+ Assert.True(cts.IsCancellationRequested);
+
+ Task t = p.Delay(TimeSpan.FromSeconds(0));
+ Assert.True(t.IsCompleted);
+
+ t = new TaskCompletionSource().Task.WaitAsync(TimeSpan.FromSeconds(0), p);
+ Assert.True(t.IsFaulted);
+ }
+
+ class InvokeCallbackCreateTimerProvider : TimeProvider
+ {
+ public override ITimer CreateTimer(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period)
+ {
+ ITimer t = base.CreateTimer(callback, state, dueTime, period);
+ if (dueTime != Timeout.InfiniteTimeSpan)
+ {
+ callback(state);
+ }
+ return t;
+ }
+ }
+#endif
+
class TimerState
{
public TimerState()
diff --git a/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs b/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs
index c8354d7f7c767..9f8ba325f1d14 100644
--- a/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs
+++ b/src/libraries/Microsoft.Bcl.TimeProvider/src/System/Threading/Tasks/TimeProviderTaskExtensions.cs
@@ -15,15 +15,25 @@ public static class TimeProviderTaskExtensions
#if !NET8_0_OR_GREATER
private sealed class DelayState : TaskCompletionSource
{
- public DelayState() : base(TaskCreationOptions.RunContinuationsAsynchronously) {}
- public ITimer Timer { get; set; }
+ public DelayState(CancellationToken cancellationToken) : base(TaskCreationOptions.RunContinuationsAsynchronously)
+ {
+ CancellationToken = cancellationToken;
+ }
+
+ public ITimer? Timer { get; set; }
+ public CancellationToken CancellationToken { get; }
public CancellationTokenRegistration Registration { get; set; }
}
private sealed class WaitAsyncState : TaskCompletionSource
{
- public WaitAsyncState() : base(TaskCreationOptions.RunContinuationsAsynchronously) { }
+ public WaitAsyncState(CancellationToken cancellationToken) : base(TaskCreationOptions.RunContinuationsAsynchronously)
+ {
+ CancellationToken = cancellationToken;
+ }
+
public readonly CancellationTokenSource ContinuationCancellation = new CancellationTokenSource();
+ public CancellationToken CancellationToken { get; }
public CancellationTokenRegistration Registration;
public ITimer? Timer;
}
@@ -66,22 +76,22 @@ public static Task Delay(this TimeProvider timeProvider, TimeSpan delay, Cancell
return Task.FromCanceled(cancellationToken);
}
- DelayState state = new();
+ DelayState state = new(cancellationToken);
- state.Timer = timeProvider.CreateTimer(delayState =>
+ state.Timer = timeProvider.CreateTimer(static delayState =>
{
DelayState s = (DelayState)delayState!;
s.TrySetResult(true);
s.Registration.Dispose();
- s?.Timer.Dispose();
+ s.Timer?.Dispose();
}, state, delay, Timeout.InfiniteTimeSpan);
- state.Registration = cancellationToken.Register(delayState =>
+ state.Registration = cancellationToken.Register(static delayState =>
{
DelayState s = (DelayState)delayState!;
- s.TrySetCanceled(cancellationToken);
+ s.TrySetCanceled(s.CancellationToken);
s.Registration.Dispose();
- s?.Timer.Dispose();
+ s.Timer?.Dispose();
}, state);
// There are race conditions where the timer fires after we have attached the cancellation callback but before the
@@ -153,7 +163,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time
return Task.FromCanceled(cancellationToken);
}
- var state = new WaitAsyncState();
+ WaitAsyncState state = new(cancellationToken);
state.Timer = timeProvider.CreateTimer(static s =>
{
@@ -162,7 +172,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time
state.TrySetException(new TimeoutException());
state.Registration.Dispose();
- state.Timer!.Dispose();
+ state.Timer?.Dispose();
state.ContinuationCancellation.Cancel();
}, state, timeout, Timeout.InfiniteTimeSpan);
@@ -182,7 +192,7 @@ public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider time
{
var state = (WaitAsyncState)s!;
- state.TrySetCanceled();
+ state.TrySetCanceled(state.CancellationToken);
state.Timer?.Dispose();
state.ContinuationCancellation.Cancel();
@@ -259,16 +269,16 @@ public static CancellationTokenSource CreateCancellationTokenSource(this TimePro
var cts = new CancellationTokenSource();
- ITimer timer = timeProvider.CreateTimer(s =>
+ ITimer timer = timeProvider.CreateTimer(static s =>
{
try
{
- ((CancellationTokenSource)s).Cancel();
+ ((CancellationTokenSource)s!).Cancel();
}
catch (ObjectDisposedException) { }
}, cts, delay, Timeout.InfiniteTimeSpan);
- cts.Token.Register(t => ((ITimer)t).Dispose(), timer);
+ cts.Token.Register(static t => ((ITimer)t!).Dispose(), timer);
return cts;
#endif // NET8_0_OR_GREATER
}