Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use absl any invocable #484

Merged
merged 6 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/cli/cli_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,7 @@ void CliConnection::close() {
}
channel_->close();
}
auto cb = std::move(disconnect_cb_);
disconnect_cb_ = nullptr;
if (cb) {
cb();
}
on_disconnect();
}

void CliConnection::SendIfNotProcessing() {
Expand Down
28 changes: 18 additions & 10 deletions src/cli/cli_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ Worker::Worker()

Worker::~Worker() {
Stop(std::function<void()>());
start_callback_ = nullptr;
stop_callback_ = nullptr;
work_guard_.reset();
delete private_;
}

void Worker::Start(std::function<void(asio::error_code)> callback) {
void Worker::Start(absl::AnyInvocable<void(asio::error_code)> &&callback) {
DCHECK_EQ(private_->cli_server.get(), nullptr);
DCHECK(!start_callback_);
start_callback_ = std::move(callback);
if (thread_ && thread_->joinable())
thread_->join();

Expand All @@ -64,7 +68,7 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
PLOG(WARNING) << "failed to set thread priority";
}
});
asio::post(io_context_, [this, callback]() {
asio::post(io_context_, [this]() {
std::string host_name = absl::GetFlag(FLAGS_local_host);
uint16_t port = absl::GetFlag(FLAGS_local_port);

Expand All @@ -75,7 +79,7 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
asio::ip::tcp::endpoint endpoint(addr, port);
auto results = asio::ip::tcp::resolver::results_type::create(
endpoint, host_name, std::to_string(port));
on_resolve_local(ec, results, callback);
on_resolve_local(ec, results);
return;
}
#ifdef HAVE_C_ARES
Expand All @@ -84,24 +88,27 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
resolver_.async_resolve(Net_ipv6works() ? asio::ip::tcp::unspec() : asio::ip::tcp::v4(),
host_name, std::to_string(port),
#endif
[this, callback](const asio::error_code& ec,
asio::ip::tcp::resolver::results_type results) {
on_resolve_local(ec, results, callback);
[this](const asio::error_code& ec,
asio::ip::tcp::resolver::results_type results) {
on_resolve_local(ec, results);
});
});
}

void Worker::Stop(std::function<void()> callback) {
void Worker::Stop(absl::AnyInvocable<void()> &&callback) {
DCHECK(!stop_callback_);
stop_callback_ = std::move(callback);
/// stop in the worker thread
if (!thread_) {
return;
}
asio::post(io_context_ ,[this, callback]() {
asio::post(io_context_ ,[this]() {
#ifdef HAVE_C_ARES
resolver_->Cancel();
#else
resolver_.cancel();
#endif
auto callback = std::move(stop_callback_);
if (private_->cli_server) {
private_->cli_server->stop();
}
Expand Down Expand Up @@ -141,8 +148,8 @@ void Worker::WorkFunc() {
}

void Worker::on_resolve_local(asio::error_code ec,
asio::ip::tcp::resolver::results_type results,
std::function<void(asio::error_code)> callback) {
asio::ip::tcp::resolver::results_type results) {
auto callback = std::move(start_callback_);
if (ec) {
LOG(WARNING) << "local resolved host:" << absl::GetFlag(FLAGS_local_host)
<< " failed due to: " << ec;
Expand Down Expand Up @@ -170,6 +177,7 @@ void Worker::on_resolve_local(asio::error_code ec,

if (ec) {
private_->cli_server->stop();
private_->cli_server.reset();
}

work_guard_.reset();
Expand Down
13 changes: 8 additions & 5 deletions src/cli/cli_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

#include "core/cipher.hpp"

#include <functional>
#include <memory>
#include <thread>

#include <absl/functional/any_invocable.h>

#include "config/config.hpp"
#include "core/asio.hpp"
#include "core/logging.hpp"
Expand All @@ -23,8 +24,8 @@ class Worker {
Worker();
~Worker();

void Start(std::function<void(asio::error_code)> callback);
void Stop(std::function<void()> callback);
void Start(absl::AnyInvocable<void(asio::error_code)> &&callback);
void Stop(absl::AnyInvocable<void()> &&callback);

std::string GetDomain() const;
std::string GetRemoteDomain() const;
Expand All @@ -35,8 +36,7 @@ class Worker {
void WorkFunc();

void on_resolve_local(asio::error_code ec,
asio::ip::tcp::resolver::results_type results,
std::function<void(asio::error_code)> callback);
asio::ip::tcp::resolver::results_type results);

asio::io_context io_context_;
/// stopping the io_context from running out of work
Expand All @@ -50,6 +50,9 @@ class Worker {
/// used to do io in another thread
std::unique_ptr<std::thread> thread_;

absl::AnyInvocable<void(asio::error_code)> start_callback_;
absl::AnyInvocable<void()> stop_callback_;

WorkerPrivate *private_;
std::vector<asio::ip::tcp::endpoint> endpoints_;
};
Expand Down
65 changes: 43 additions & 22 deletions src/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
#include "net/ssl_server_socket.hpp"
#include "protocol.hpp"

#include <absl/functional/any_invocable.h>

class Downlink {
public:
using io_handle_t = std::function<void(asio::error_code, std::size_t)>;
using handle_t = std::function<void(asio::error_code)>;
using io_handle_t = absl::AnyInvocable<void(asio::error_code, std::size_t)>;
using handle_t = absl::AnyInvocable<void(asio::error_code)>;

Downlink(asio::io_context& io_context)
: io_context_(io_context), socket_(io_context_) {}
Expand All @@ -31,7 +33,7 @@ class Downlink {
}

public:
virtual void handshake(handle_t cb) {
virtual void handshake(handle_t &&cb) {
cb(asio::error_code());
}

Expand All @@ -43,23 +45,23 @@ class Downlink {
return false;
}

virtual void async_read_some(handle_t cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_read, cb);
virtual void async_read_some(handle_t &&cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_read, std::move(cb));
}

virtual size_t read_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) {
return socket_.read_some(tail_buffer(*buf), ec);
}

virtual void async_write_some(handle_t cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_write, cb);
virtual void async_write_some(handle_t &&cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_write, std::move(cb));
}

virtual size_t write_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) {
return socket_.write_some(const_buffer(*buf), ec);
}

virtual void async_shutdown(handle_t cb) {
virtual void async_shutdown(handle_t &&cb) {
asio::error_code ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_send, ec);
cb(ec);
Expand All @@ -82,6 +84,7 @@ class Downlink {
public:
asio::io_context& io_context_;
asio::ip::tcp::socket socket_;
handle_t handshake_callback_; // FIXME handle it gracefully
};

class SSLDownlink : public Downlink {
Expand All @@ -94,12 +97,18 @@ class SSLDownlink : public Downlink {
ssl_socket_(net::SSLServerSocket::Create(&io_context, &socket_, ssl_ctx->native_handle())) {
}

~SSLDownlink() override {}
~SSLDownlink() override { DCHECK(!handshake_callback_); }

void handshake(handle_t cb) override {
ssl_socket_->Handshake([cb](int result) {
void handshake(handle_t &&cb) override {
DCHECK(!handshake_callback_);
handshake_callback_ = std::move(cb);
ssl_socket_->Handshake([this](int result) {
auto callback = std::move(handshake_callback_);
DCHECK(!handshake_callback_);
asio::error_code ec = result == net::OK ? asio::error_code() : asio::error::connection_refused ;
cb(ec);
if (callback) {
callback(ec);
}
});
}

Expand All @@ -114,24 +123,24 @@ class SSLDownlink : public Downlink {
return false;
}

void async_read_some(handle_t cb) override {
ssl_socket_->WaitRead(cb);
void async_read_some(handle_t &&cb) override {
ssl_socket_->WaitRead(std::move(cb));
}

size_t read_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) override {
return ssl_socket_->Read(buf, ec);
}

void async_write_some(handle_t cb) override {
ssl_socket_->WaitWrite(cb);
void async_write_some(handle_t &&cb) override {
ssl_socket_->WaitWrite(std::move(cb));
}

size_t write_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) override {
return ssl_socket_->Write(buf, ec);
}

void async_shutdown(handle_t cb) override {
ssl_socket_->Shutdown(cb);
void async_shutdown(handle_t &&cb) override {
ssl_socket_->Shutdown(std::move(cb));
}

void shutdown(asio::error_code &ec) override {
Expand Down Expand Up @@ -242,7 +251,18 @@ class Connection {
/// set callback
///
/// \param cb the callback function pointer when disconnect happens
void set_disconnect_cb(std::function<void()> cb) { disconnect_cb_ = cb; }
void set_disconnect_cb(absl::AnyInvocable<void()> &&cb) { disconnect_cb_ = std::move(cb); }

/// call callback
///
void on_disconnect() {
downlink_->handshake_callback_ = nullptr;
auto cb = std::move(disconnect_cb_);
DCHECK(!disconnect_cb_);
if (cb) {
cb();
}
}

asio::io_context& io_context() { return *io_context_; }

Expand Down Expand Up @@ -294,14 +314,15 @@ class Connection {

std::unique_ptr<Downlink> downlink_;

/// the callback invoked when disconnect event happens
std::function<void()> disconnect_cb_;

protected:
/// statistics of read bytes
size_t rbytes_transferred_ = 0;
/// statistics of written bytes
size_t wbytes_transferred_ = 0;

private:
/// the callback invoked when disconnect event happens
absl::AnyInvocable<void()> disconnect_cb_;
};

class ConnectionFactory {
Expand Down
Loading
Loading