diff --git a/cpp/mrc/src/public/coroutines/io_scheduler.cpp b/cpp/mrc/src/public/coroutines/io_scheduler.cpp index 5d55bc14f..7096ad160 100644 --- a/cpp/mrc/src/public/coroutines/io_scheduler.cpp +++ b/cpp/mrc/src/public/coroutines/io_scheduler.cpp @@ -179,42 +179,30 @@ auto IoScheduler::yield_for(std::chrono::milliseconds amount) -> mrc::coroutines } else { - // Yield/timeout tasks are considered live in the scheduler and must be accounted for. Note - // that if the user gives an invalid amount and schedule() is directly called it will account - // for the scheduled task there. - m_size.fetch_add(1, std::memory_order::release); - - // Yielding does not requiring setting the timer position on the poll info since - // it doesn't have a corresponding 'event' that can trigger, it always waits for - // the timeout to occur before resuming. - - detail::PollInfo pi{}; - add_timer_token(clock_t::now() + amount, pi); - co_await pi; - - m_size.fetch_sub(1, std::memory_order::release); + co_return co_await yield_until(coroutines::clock_t::now() + amount); } - co_return; } auto IoScheduler::yield_until(time_point_t time) -> mrc::coroutines::Task { - auto now = clock_t::now(); - // If the requested time is in the past (or now!) bail out! - if (time <= now) + if (time <= clock_t::now()) { co_await schedule(); } else { + // Yield/timeout tasks are considered live in the scheduler and must be accounted for. Note + // that if the user gives an invalid amount and schedule() is directly called it will account + // for the scheduled task there. m_size.fetch_add(1, std::memory_order::release); - auto amount = std::chrono::duration_cast(time - now); - - detail::PollInfo pi{}; - add_timer_token(now + amount, pi); - co_await pi; + // Yielding does not requiring setting the timer position on the poll info since + // it doesn't have a corresponding 'event' that can trigger, it always waits for + // the timeout to occur before resuming. + detail::PollInfo poll_info{}; + add_timer_token(time, poll_info); + co_await poll_info; m_size.fetch_sub(1, std::memory_order::release); } diff --git a/cpp/mrc/tests/coroutines/test_io_scheduler.cpp b/cpp/mrc/tests/coroutines/test_io_scheduler.cpp index 600f7b660..4aa7d84c4 100644 --- a/cpp/mrc/tests/coroutines/test_io_scheduler.cpp +++ b/cpp/mrc/tests/coroutines/test_io_scheduler.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -41,42 +42,69 @@ TEST_F(TestCoroIoScheduler, YieldFor) { auto scheduler = coroutines::IoScheduler::get_instance(); + static constexpr std::chrono::milliseconds Delay{10}; + auto task = [scheduler]() -> coroutines::Task<> { - co_await scheduler->yield_for(10ms); + co_await scheduler->yield_for(Delay); }; + auto start = coroutines::clock_t::now(); coroutines::sync_wait(task()); + auto stop = coroutines::clock_t::now(); + + ASSERT_GE(stop - start, Delay); } TEST_F(TestCoroIoScheduler, YieldUntil) { auto scheduler = coroutines::IoScheduler::get_instance(); - auto task = [scheduler]() -> coroutines::Task<> { - co_await scheduler->yield_until(coroutines::clock_t::now() + 10ms); + coroutines::clock_t::time_point target_time{}; + + auto task = [scheduler, &target_time]() -> coroutines::Task<> { + target_time = coroutines::clock_t::now() + 10ms; + co_await scheduler->yield_until(target_time); }; coroutines::sync_wait(task()); + + auto current_time = coroutines::clock_t::now(); + + ASSERT_GE(current_time, target_time); } TEST_F(TestCoroIoScheduler, Concurrent) { auto scheduler = coroutines::IoScheduler::get_instance(); + auto per_task_overhead = [&] { + static constexpr std::chrono::milliseconds SmallestDelay{1}; + auto start = coroutines::clock_t::now(); + coroutines::sync_wait([scheduler]() -> coroutines::Task<> { + co_await scheduler->yield_for(SmallestDelay); + }()); + auto stop = coroutines::clock_t::now(); + return (stop - start) - SmallestDelay; + }(); + + static constexpr std::chrono::milliseconds TaskDuration{10}; + auto task = [scheduler]() -> coroutines::Task<> { - co_await scheduler->yield_for(10ms); + co_await scheduler->yield_for(TaskDuration); }; auto start = coroutines::clock_t::now(); std::vector> tasks; - for (uint32_t i = 0; i < 1000; i++) + const uint32_t NumTasks{1'000}; + for (uint32_t i = 0; i < NumTasks; i++) { tasks.push_back(task()); } coroutines::sync_wait(coroutines::when_all(std::move(tasks))); + auto stop = coroutines::clock_t::now(); - ASSERT_LT(coroutines::clock_t::now() - start, 20ms); + ASSERT_LT(stop - start, TaskDuration + per_task_overhead * NumTasks); }