diff --git a/src/cli/cli_worker.cpp b/src/cli/cli_worker.cpp index 7928b0754..040509041 100644 --- a/src/cli/cli_worker.cpp +++ b/src/cli/cli_worker.cpp @@ -45,12 +45,16 @@ Worker::Worker() Worker::~Worker() { Stop(std::function()); + start_callback_ = nullptr; + stop_callback_ = nullptr; work_guard_.reset(); delete private_; } -void Worker::Start(std::function callback) { +void Worker::Start(absl::AnyInvocable &&callback) { DCHECK_EQ(private_->cli_server.get(), nullptr); + DCHECK(!start_callback_); + start_callback_ = std::move(callback); if (thread_ && thread_->joinable()) thread_->join(); @@ -64,7 +68,7 @@ void Worker::Start(std::function callback) { PLOG(WARNING) << "failed to set thread priority"; } }); - asio::post(io_context_, [this, callback]() { + asio::post(io_context_, [this]() { std::string host_name = absl::GetFlag(FLAGS_local_host); uint16_t port = absl::GetFlag(FLAGS_local_port); @@ -75,7 +79,7 @@ void Worker::Start(std::function callback) { asio::ip::tcp::endpoint endpoint(addr, port); auto results = asio::ip::tcp::resolver::results_type::create( endpoint, host_name, std::to_string(port)); - on_resolve_local(ec, results, callback); + on_resolve_local(ec, results); return; } #ifdef HAVE_C_ARES @@ -84,24 +88,27 @@ void Worker::Start(std::function callback) { resolver_.async_resolve(Net_ipv6works() ? asio::ip::tcp::unspec() : asio::ip::tcp::v4(), host_name, std::to_string(port), #endif - [this, callback](const asio::error_code& ec, - asio::ip::tcp::resolver::results_type results) { - on_resolve_local(ec, results, callback); + [this](const asio::error_code& ec, + asio::ip::tcp::resolver::results_type results) { + on_resolve_local(ec, results); }); }); } -void Worker::Stop(std::function callback) { +void Worker::Stop(absl::AnyInvocable &&callback) { + DCHECK(!stop_callback_); + stop_callback_ = std::move(callback); /// stop in the worker thread if (!thread_) { return; } - asio::post(io_context_ ,[this, callback]() { + asio::post(io_context_ ,[this]() { #ifdef HAVE_C_ARES resolver_->Cancel(); #else resolver_.cancel(); #endif + auto callback = std::move(stop_callback_); if (private_->cli_server) { private_->cli_server->stop(); } @@ -141,8 +148,8 @@ void Worker::WorkFunc() { } void Worker::on_resolve_local(asio::error_code ec, - asio::ip::tcp::resolver::results_type results, - std::function callback) { + asio::ip::tcp::resolver::results_type results) { + auto callback = std::move(start_callback_); if (ec) { LOG(WARNING) << "local resolved host:" << absl::GetFlag(FLAGS_local_host) << " failed due to: " << ec; diff --git a/src/cli/cli_worker.hpp b/src/cli/cli_worker.hpp index 1d978be93..592b51d21 100644 --- a/src/cli/cli_worker.hpp +++ b/src/cli/cli_worker.hpp @@ -5,10 +5,11 @@ #include "core/cipher.hpp" -#include #include #include +#include + #include "config/config.hpp" #include "core/asio.hpp" #include "core/logging.hpp" @@ -23,8 +24,8 @@ class Worker { Worker(); ~Worker(); - void Start(std::function callback); - void Stop(std::function callback); + void Start(absl::AnyInvocable &&callback); + void Stop(absl::AnyInvocable &&callback); std::string GetDomain() const; std::string GetRemoteDomain() const; @@ -35,8 +36,7 @@ class Worker { void WorkFunc(); void on_resolve_local(asio::error_code ec, - asio::ip::tcp::resolver::results_type results, - std::function callback); + asio::ip::tcp::resolver::results_type results); asio::io_context io_context_; /// stopping the io_context from running out of work @@ -50,6 +50,9 @@ class Worker { /// used to do io in another thread std::unique_ptr thread_; + absl::AnyInvocable start_callback_; + absl::AnyInvocable stop_callback_; + WorkerPrivate *private_; std::vector endpoints_; };