From 87719510bee0656f004529639846cbe20b013e19 Mon Sep 17 00:00:00 2001 From: vitsai Date: Sat, 1 Jul 2023 00:16:00 -0700 Subject: [PATCH] fix other direction too oops Signed-off-by: vitsai --- src/ray/common/asio/instrumented_io_context.h | 33 +++++++++++++++---- src/ray/gcs/gcs_client/gcs_client.cc | 2 +- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/ray/common/asio/instrumented_io_context.h b/src/ray/common/asio/instrumented_io_context.h index 737193543462..df05397cc021 100644 --- a/src/ray/common/asio/instrumented_io_context.h +++ b/src/ray/common/asio/instrumented_io_context.h @@ -29,26 +29,46 @@ class instrumented_io_context : public boost::asio::io_context { /// Initializes the global stats struct after calling the base contructor. /// TODO(ekl) allow taking an externally defined event tracker. instrumented_io_context() - : event_stats_(std::make_shared()), is_running_(false) {} + : event_stats_(std::make_shared()), run_count_(0) {} - bool running() { return is_running_.load(); } + void stop_if_solo() { + absl::MutexLock l(&mu_); + if (run_count_ == 1) { + run_count_ = 0; + boost::asio::io_context::stop(); + } + } bool run_if_stopped(std::function callback) { - if (!is_running_.exchange(true)) { + size_t old_run_count; + { + absl::MutexLock l(&mu_); + old_run_count = run_count_; + run_count_++; + } + + if (old_run_count == 0) { callback(); boost::asio::io_context::run(); return true; + } else { + absl::MutexLock l(&mu_); + run_count_--; } return false; } void run() { - is_running_.store(true); + { + absl::MutexLock l(&mu_); + run_count_++; + } boost::asio::io_context::run(); } void stop() { - is_running_.store(false); + absl::MutexLock l(&mu_); + run_count_ = 0; boost::asio::io_context::stop(); } @@ -81,5 +101,6 @@ class instrumented_io_context : public boost::asio::io_context { /// The event stats tracker to use to record asio handler stats to. std::shared_ptr event_stats_; - std::atomic is_running_; + absl::Mutex mu_; + size_t run_count_ GUARDED_BY(mu_); }; diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index b19e76645184..a0b1a4c87ac4 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -104,7 +104,7 @@ Status GcsClient::Connect(instrumented_io_context &io_service, RAY_LOG(DEBUG) << "Setting cluster ID to " << cluster_id; client_call_manager_->SetClusterId(cluster_id); if (do_stop.get()) { - io_service.stop(); + io_service.stop_if_solo(); } else { wait_sync.set_value(true); }