Skip to content

Commit

Permalink
support select coro and channel (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
qicosmos authored Mar 16, 2024
1 parent 252c4a7 commit e7c3527
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 9 deletions.
35 changes: 30 additions & 5 deletions include/cinatra/ylt/coro_io/coro_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <async_simple/coro/Sleep.h>
#include <async_simple/coro/SyncAwait.h>

#include "async_simple/coro/Collect.h"

#if defined(YLT_ENABLE_SSL) || defined(CINATRA_ENABLE_SSL)
#include <asio/ssl.hpp>
#endif
Expand Down Expand Up @@ -333,8 +335,24 @@ post(Func func,
co_return co_await awaitor.await_resume(helper);
}

template <typename R>
struct coro_channel
: public asio::experimental::channel<void(std::error_code, R)> {
using return_type = R;
using ValueType = std::pair<std::error_code, R>;
using asio::experimental::channel<void(std::error_code, R)>::channel;
};

template <typename R>
inline coro_channel<R> create_channel(
size_t capacity,
asio::io_context::executor_type executor =
coro_io::get_global_block_executor()->get_asio_executor()) {
return coro_channel<R>(executor, capacity);
}

template <typename T>
async_simple::coro::Lazy<std::error_code> async_send(
inline async_simple::coro::Lazy<std::error_code> async_send(
asio::experimental::channel<void(std::error_code, T)> &channel, T val) {
callback_awaitor<std::error_code> awaitor;
co_return co_await awaitor.await_resume(
Expand All @@ -345,17 +363,24 @@ async_simple::coro::Lazy<std::error_code> async_send(
});
}

template <typename R>
async_simple::coro::Lazy<std::pair<std::error_code, R>> async_receive(
asio::experimental::channel<void(std::error_code, R)> &channel) {
callback_awaitor<std::pair<std::error_code, R>> awaitor;
template <typename Channel>
async_simple::coro::Lazy<std::pair<
std::error_code,
typename Channel::return_type>> inline async_receive(Channel &channel) {
callback_awaitor<std::pair<std::error_code, typename Channel::return_type>>
awaitor;
co_return co_await awaitor.await_resume([&](auto handler) {
channel.async_receive([handler](auto ec, auto val) {
handler.set_value_then_resume(std::make_pair(ec, std::move(val)));
});
});
}

template <typename... T>
auto select(T &&...args) {
return async_simple::coro::collectAny(std::forward<T>(args)...);
}

template <typename Socket, typename AsioBuffer>
std::pair<asio::error_code, size_t> read_some(Socket &sock,
AsioBuffer &&buffer) {
Expand Down
117 changes: 113 additions & 4 deletions tests/test_cinatra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ TEST_CASE("test cinatra::string SSO to no SSO") {
}

TEST_CASE("test coro channel") {
auto ctx = coro_io::get_global_block_executor()->get_asio_executor();
asio::experimental::channel<void(std::error_code, int)> ch(ctx, 10000);
auto ch = coro_io::create_channel<int>(1000);
auto ec = async_simple::coro::syncAwait(coro_io::async_send(ch, 41));
CHECK(!ec);
ec = async_simple::coro::syncAwait(coro_io::async_send(ch, 42));
Expand All @@ -198,16 +197,126 @@ TEST_CASE("test coro channel") {
std::error_code err;
int val;
std::tie(err, val) =
async_simple::coro::syncAwait(coro_io::async_receive<int>(ch));
async_simple::coro::syncAwait(coro_io::async_receive(ch));
CHECK(!err);
CHECK(val == 41);

std::tie(err, val) =
async_simple::coro::syncAwait(coro_io::async_receive<int>(ch));
async_simple::coro::syncAwait(coro_io::async_receive(ch));
CHECK(!err);
CHECK(val == 42);
}

async_simple::coro::Lazy<void> test_select_channel() {
using namespace coro_io;
using namespace async_simple;
using namespace async_simple::coro;

auto ch1 = coro_io::create_channel<int>(1000);
auto ch2 = coro_io::create_channel<int>(1000);

co_await async_send(ch1, 41);
co_await async_send(ch2, 42);

std::array<int, 2> arr{41, 42};
int val;

size_t index =
co_await select(std::pair{async_receive(ch1),
[&val](auto value) {
auto [ec, r] = value.value();
val = r;
}},
std::pair{async_receive(ch2), [&val](auto value) {
auto [ec, r] = value.value();
val = r;
}});

CHECK(val == arr[index]);

co_await async_send(ch1, 41);
co_await async_send(ch2, 42);

std::vector<Lazy<std::pair<std::error_code, int>>> vec;
vec.push_back(async_receive(ch1));
vec.push_back(async_receive(ch2));

index = co_await select(std::move(vec), [&](size_t i, auto result) {
val = result.value().second;
});
CHECK(val == arr[index]);

period_timer timer1(coro_io::get_global_executor());
timer1.expires_after(100ms);
period_timer timer2(coro_io::get_global_executor());
timer2.expires_after(200ms);

int val1;
index = co_await select(std::pair{timer1.async_await(),
[&](auto val) {
CHECK(val.value());
val1 = 0;
}},
std::pair{timer2.async_await(), [&](auto val) {
CHECK(val.value());
val1 = 0;
}});
CHECK(index == val1);

int val2;
index = co_await select(std::pair{coro_io::post([] {
}),
[&](auto) {
std::cout << "post1\n";
val2 = 0;
}},
std::pair{coro_io::post([] {
}),
[&](auto) {
std::cout << "post2\n";
val2 = 1;
}});
CHECK(index == val2);

co_await async_send(ch1, 43);
auto lazy = coro_io::post([] {
});

int val3 = -1;
index = co_await select(std::pair{async_receive(ch1),
[&](auto result) {
val3 = result.value().second;
}},
std::pair{std::move(lazy), [&](auto) {
val3 = 0;
}});

if (index == 0) {
CHECK(val3 == 43);
}
else if (index == 1) {
CHECK(val3 == 0);
}
}

TEST_CASE("test select coro channel") {
using namespace coro_io;
async_simple::coro::syncAwait(test_select_channel());

auto ch = coro_io::create_channel<int>(1000);

async_simple::coro::syncAwait(coro_io::async_send(ch, 41));
async_simple::coro::syncAwait(coro_io::async_send(ch, 42));

std::error_code ec;
int val;
std::tie(ec, val) = async_simple::coro::syncAwait(coro_io::async_receive(ch));
CHECK(val == 41);

std::tie(ec, val) = async_simple::coro::syncAwait(coro_io::async_receive(ch));
CHECK(val == 42);
}

async_simple::coro::Lazy<void> test_collect_all() {
asio::io_context ioc;
std::thread thd([&] {
Expand Down

0 comments on commit e7c3527

Please sign in to comment.