Skip to content

Commit

Permalink
Improved Task Lifecycle (#259)
Browse files Browse the repository at this point in the history
This update allows the Task to transfer ownership of the coroutine handle to the `awaiter` returned by the Task's `operator co_await()` method.

Without this update [the linked test would fail on this line](https://github.com/nv-morpheus/MRC/compare/branch-23.01...ryanolson:task_lifecycle?expand=1#diff-abd2fa1b390dadc99ee7076e518e5ba9029bb0618c908af2ba54134c519f7277R135), because the Task is created in the Awaitable's `operator co_await()` method would not transfer ownership of the coroutines handle to the awaitable, so the Task's destructor would destroy the coroutine who's awaited the caller is awaiting on.

Authors:
  - Ryan Olson (https://github.com/ryanolson)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #259
  • Loading branch information
ryanolson authored Jan 14, 2023
1 parent 9dc7cb2 commit 5a1a5f4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 9 deletions.
4 changes: 3 additions & 1 deletion cpp/mrc/include/mrc/coroutines/sync_wait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

#include <condition_variable>
#include <mutex>
#include <type_traits>
#include <utility>

namespace mrc::coroutines {

Expand Down Expand Up @@ -243,7 +245,7 @@ class SyncWaitTask
}
else
{
return m_coroutine.promise().result();
return std::remove_reference_t<ReturnT>{std::move(m_coroutine.promise().result())};
}
}

Expand Down
24 changes: 18 additions & 6 deletions cpp/mrc/include/mrc/coroutines/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct Promise final : public PromiseBase
m_return_value = std::move(value);
}

auto result() const& -> const ReturnT&
auto result() & -> ReturnT&
{
if (m_exception_ptr)
{
Expand Down Expand Up @@ -219,7 +219,7 @@ class [[nodiscard]] Task
{
if (std::addressof(other) != this)
{
if (m_coroutine != nullptr)
if (m_coroutine)
{
m_coroutine.destroy();
}
Expand Down Expand Up @@ -252,7 +252,7 @@ class [[nodiscard]] Task

auto destroy() -> bool
{
if (m_coroutine != nullptr)
if (m_coroutine)
{
m_coroutine.destroy();
m_coroutine = nullptr;
Expand All @@ -262,7 +262,7 @@ class [[nodiscard]] Task
return false;
}

auto operator co_await() const& noexcept
auto operator co_await() const&
{
struct Awaitable : public AwaitableBase
{
Expand All @@ -276,18 +276,28 @@ class [[nodiscard]] Task
}
else
{
// returns a reference to the value held by the promise
return this->m_coroutine.promise().result();
}
}
};

// the task is responsible for the destruction of the coroutine and promise
return Awaitable{m_coroutine};
}

auto operator co_await() const&& noexcept
auto operator co_await() &&
{
struct Awaitable : public AwaitableBase
{
~Awaitable()
{
if (this->m_coroutine)
{
this->m_coroutine.destroy();
}
}

auto await_resume() -> decltype(auto)
{
if constexpr (std::is_same_v<void, ReturnT>)
Expand All @@ -298,12 +308,14 @@ class [[nodiscard]] Task
}
else
{
// moves the value held by the promise to the caller
return std::move(this->m_coroutine.promise()).result();
}
}
};

return Awaitable{m_coroutine};
// the awaiter is responsible for the destruction of the coroutine and promise
return Awaitable{std::exchange(m_coroutine, nullptr)};
}

auto promise() & -> promise_type&
Expand Down
4 changes: 2 additions & 2 deletions cpp/mrc/include/mrc/coroutines/when_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ class when_all_task

~when_all_task()
{
if (m_coroutine != nullptr)
if (m_coroutine)
{
m_coroutine.destroy();
}
Expand Down Expand Up @@ -526,7 +526,7 @@ class when_all_task
}
else
{
return m_coroutine.promise().return_value();
return std::remove_reference_t<ReturnT>{std::move(m_coroutine.promise().return_value())};
}
}

Expand Down
36 changes: 36 additions & 0 deletions cpp/mrc/tests/coroutines/test_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
* limitations under the License.
*/

#include "mrc/core/std23_expected.hpp"
#include "mrc/core/thread.hpp"
#include "mrc/coroutines/ring_buffer.hpp"
#include "mrc/coroutines/sync_wait.hpp"
Expand Down Expand Up @@ -118,3 +119,38 @@ TEST_F(TestCoroTask, RingBufferStressTest)
coroutines::sync_wait(coroutines::when_all(source(), sink()));
}
}

// this is our awaitable
class AwaitableTaskProvider
{
public:
struct Done
{};

AwaitableTaskProvider()
{
m_task_generator = []() -> coroutines::Task<std23::expected<int, Done>> { co_return{42}; };
}

auto operator co_await() -> decltype(auto)
{
return m_task_generator().operator co_await();
}

private:
std::function<coroutines::Task<std23::expected<int, Done>>()> m_task_generator;
};

TEST_F(TestCoroTask, AwaitableTaskProvider)
{
auto expected = coroutines::sync_wait(AwaitableTaskProvider{});
EXPECT_EQ(*expected, 42);

auto task = []() -> coroutines::Task<void> {
auto expected = co_await AwaitableTaskProvider{};
EXPECT_EQ(*expected, 42);
co_return;
};

coroutines::sync_wait(task());
}

0 comments on commit 5a1a5f4

Please sign in to comment.