Skip to content

Commit

Permalink
Merge pull request #484 from Chilledheart/use_absl_any_invocable
Browse files Browse the repository at this point in the history
Use absl any invocable
  • Loading branch information
Chilledheart authored Dec 4, 2023
2 parents 731eea4 + 03f7196 commit b6cb33c
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 141 deletions.
6 changes: 1 addition & 5 deletions src/cli/cli_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,7 @@ void CliConnection::close() {
}
channel_->close();
}
auto cb = std::move(disconnect_cb_);
disconnect_cb_ = nullptr;
if (cb) {
cb();
}
on_disconnect();
}

void CliConnection::SendIfNotProcessing() {
Expand Down
28 changes: 18 additions & 10 deletions src/cli/cli_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ Worker::Worker()

Worker::~Worker() {
Stop(std::function<void()>());
start_callback_ = nullptr;
stop_callback_ = nullptr;
work_guard_.reset();
delete private_;
}

void Worker::Start(std::function<void(asio::error_code)> callback) {
void Worker::Start(absl::AnyInvocable<void(asio::error_code)> &&callback) {
DCHECK_EQ(private_->cli_server.get(), nullptr);
DCHECK(!start_callback_);
start_callback_ = std::move(callback);
if (thread_ && thread_->joinable())
thread_->join();

Expand All @@ -64,7 +68,7 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
PLOG(WARNING) << "failed to set thread priority";
}
});
asio::post(io_context_, [this, callback]() {
asio::post(io_context_, [this]() {
std::string host_name = absl::GetFlag(FLAGS_local_host);
uint16_t port = absl::GetFlag(FLAGS_local_port);

Expand All @@ -75,7 +79,7 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
asio::ip::tcp::endpoint endpoint(addr, port);
auto results = asio::ip::tcp::resolver::results_type::create(
endpoint, host_name, std::to_string(port));
on_resolve_local(ec, results, callback);
on_resolve_local(ec, results);
return;
}
#ifdef HAVE_C_ARES
Expand All @@ -84,24 +88,27 @@ void Worker::Start(std::function<void(asio::error_code)> callback) {
resolver_.async_resolve(Net_ipv6works() ? asio::ip::tcp::unspec() : asio::ip::tcp::v4(),
host_name, std::to_string(port),
#endif
[this, callback](const asio::error_code& ec,
asio::ip::tcp::resolver::results_type results) {
on_resolve_local(ec, results, callback);
[this](const asio::error_code& ec,
asio::ip::tcp::resolver::results_type results) {
on_resolve_local(ec, results);
});
});
}

void Worker::Stop(std::function<void()> callback) {
void Worker::Stop(absl::AnyInvocable<void()> &&callback) {
DCHECK(!stop_callback_);
stop_callback_ = std::move(callback);
/// stop in the worker thread
if (!thread_) {
return;
}
asio::post(io_context_ ,[this, callback]() {
asio::post(io_context_ ,[this]() {
#ifdef HAVE_C_ARES
resolver_->Cancel();
#else
resolver_.cancel();
#endif
auto callback = std::move(stop_callback_);
if (private_->cli_server) {
private_->cli_server->stop();
}
Expand Down Expand Up @@ -141,8 +148,8 @@ void Worker::WorkFunc() {
}

void Worker::on_resolve_local(asio::error_code ec,
asio::ip::tcp::resolver::results_type results,
std::function<void(asio::error_code)> callback) {
asio::ip::tcp::resolver::results_type results) {
auto callback = std::move(start_callback_);
if (ec) {
LOG(WARNING) << "local resolved host:" << absl::GetFlag(FLAGS_local_host)
<< " failed due to: " << ec;
Expand Down Expand Up @@ -170,6 +177,7 @@ void Worker::on_resolve_local(asio::error_code ec,

if (ec) {
private_->cli_server->stop();
private_->cli_server.reset();
}

work_guard_.reset();
Expand Down
13 changes: 8 additions & 5 deletions src/cli/cli_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

#include "core/cipher.hpp"

#include <functional>
#include <memory>
#include <thread>

#include <absl/functional/any_invocable.h>

#include "config/config.hpp"
#include "core/asio.hpp"
#include "core/logging.hpp"
Expand All @@ -23,8 +24,8 @@ class Worker {
Worker();
~Worker();

void Start(std::function<void(asio::error_code)> callback);
void Stop(std::function<void()> callback);
void Start(absl::AnyInvocable<void(asio::error_code)> &&callback);
void Stop(absl::AnyInvocable<void()> &&callback);

std::string GetDomain() const;
std::string GetRemoteDomain() const;
Expand All @@ -35,8 +36,7 @@ class Worker {
void WorkFunc();

void on_resolve_local(asio::error_code ec,
asio::ip::tcp::resolver::results_type results,
std::function<void(asio::error_code)> callback);
asio::ip::tcp::resolver::results_type results);

asio::io_context io_context_;
/// stopping the io_context from running out of work
Expand All @@ -50,6 +50,9 @@ class Worker {
/// used to do io in another thread
std::unique_ptr<std::thread> thread_;

absl::AnyInvocable<void(asio::error_code)> start_callback_;
absl::AnyInvocable<void()> stop_callback_;

WorkerPrivate *private_;
std::vector<asio::ip::tcp::endpoint> endpoints_;
};
Expand Down
65 changes: 43 additions & 22 deletions src/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
#include "net/ssl_server_socket.hpp"
#include "protocol.hpp"

#include <absl/functional/any_invocable.h>

class Downlink {
public:
using io_handle_t = std::function<void(asio::error_code, std::size_t)>;
using handle_t = std::function<void(asio::error_code)>;
using io_handle_t = absl::AnyInvocable<void(asio::error_code, std::size_t)>;
using handle_t = absl::AnyInvocable<void(asio::error_code)>;

Downlink(asio::io_context& io_context)
: io_context_(io_context), socket_(io_context_) {}
Expand All @@ -31,7 +33,7 @@ class Downlink {
}

public:
virtual void handshake(handle_t cb) {
virtual void handshake(handle_t &&cb) {
cb(asio::error_code());
}

Expand All @@ -43,23 +45,23 @@ class Downlink {
return false;
}

virtual void async_read_some(handle_t cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_read, cb);
virtual void async_read_some(handle_t &&cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_read, std::move(cb));
}

virtual size_t read_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) {
return socket_.read_some(tail_buffer(*buf), ec);
}

virtual void async_write_some(handle_t cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_write, cb);
virtual void async_write_some(handle_t &&cb) {
socket_.async_wait(asio::ip::tcp::socket::wait_write, std::move(cb));
}

virtual size_t write_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) {
return socket_.write_some(const_buffer(*buf), ec);
}

virtual void async_shutdown(handle_t cb) {
virtual void async_shutdown(handle_t &&cb) {
asio::error_code ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_send, ec);
cb(ec);
Expand All @@ -82,6 +84,7 @@ class Downlink {
public:
asio::io_context& io_context_;
asio::ip::tcp::socket socket_;
handle_t handshake_callback_; // FIXME handle it gracefully
};

class SSLDownlink : public Downlink {
Expand All @@ -94,12 +97,18 @@ class SSLDownlink : public Downlink {
ssl_socket_(net::SSLServerSocket::Create(&io_context, &socket_, ssl_ctx->native_handle())) {
}

~SSLDownlink() override {}
~SSLDownlink() override { DCHECK(!handshake_callback_); }

void handshake(handle_t cb) override {
ssl_socket_->Handshake([cb](int result) {
void handshake(handle_t &&cb) override {
DCHECK(!handshake_callback_);
handshake_callback_ = std::move(cb);
ssl_socket_->Handshake([this](int result) {
auto callback = std::move(handshake_callback_);
DCHECK(!handshake_callback_);
asio::error_code ec = result == net::OK ? asio::error_code() : asio::error::connection_refused ;
cb(ec);
if (callback) {
callback(ec);
}
});
}

Expand All @@ -114,24 +123,24 @@ class SSLDownlink : public Downlink {
return false;
}

void async_read_some(handle_t cb) override {
ssl_socket_->WaitRead(cb);
void async_read_some(handle_t &&cb) override {
ssl_socket_->WaitRead(std::move(cb));
}

size_t read_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) override {
return ssl_socket_->Read(buf, ec);
}

void async_write_some(handle_t cb) override {
ssl_socket_->WaitWrite(cb);
void async_write_some(handle_t &&cb) override {
ssl_socket_->WaitWrite(std::move(cb));
}

size_t write_some(std::shared_ptr<IOBuf> buf, asio::error_code &ec) override {
return ssl_socket_->Write(buf, ec);
}

void async_shutdown(handle_t cb) override {
ssl_socket_->Shutdown(cb);
void async_shutdown(handle_t &&cb) override {
ssl_socket_->Shutdown(std::move(cb));
}

void shutdown(asio::error_code &ec) override {
Expand Down Expand Up @@ -242,7 +251,18 @@ class Connection {
/// set callback
///
/// \param cb the callback function pointer when disconnect happens
void set_disconnect_cb(std::function<void()> cb) { disconnect_cb_ = cb; }
void set_disconnect_cb(absl::AnyInvocable<void()> &&cb) { disconnect_cb_ = std::move(cb); }

/// call callback
///
void on_disconnect() {
downlink_->handshake_callback_ = nullptr;
auto cb = std::move(disconnect_cb_);
DCHECK(!disconnect_cb_);
if (cb) {
cb();
}
}

asio::io_context& io_context() { return *io_context_; }

Expand Down Expand Up @@ -294,14 +314,15 @@ class Connection {

std::unique_ptr<Downlink> downlink_;

/// the callback invoked when disconnect event happens
std::function<void()> disconnect_cb_;

protected:
/// statistics of read bytes
size_t rbytes_transferred_ = 0;
/// statistics of written bytes
size_t wbytes_transferred_ = 0;

private:
/// the callback invoked when disconnect event happens
absl::AnyInvocable<void()> disconnect_cb_;
};

class ConnectionFactory {
Expand Down
Loading

0 comments on commit b6cb33c

Please sign in to comment.