From b1db3cef925973e744d9030e0fac6952d7fe132b Mon Sep 17 00:00:00 2001 From: jonb377 Date: Wed, 1 Nov 2023 09:50:36 -0700 Subject: [PATCH] Support PreemptionSyncManager in XlaCoordinator (#5733) * Support PreemptionSyncManager in DistributedRuntime * Refactor to be owned by ComputationClient * Clean up logging macro issue handling --- test/spmd/test_xla_distributed_checkpoint.py | 40 +++++++++++++- torch_xla/csrc/BUILD | 1 + torch_xla/csrc/init_python_bindings.cpp | 40 ++++++++++++++ torch_xla/csrc/runtime/BUILD | 10 ++-- torch_xla/csrc/runtime/computation_client.h | 17 ++++++ torch_xla/csrc/runtime/distributed_runtime.h | 38 ------------- .../csrc/runtime/pjrt_computation_client.cc | 42 ++++++++++++--- .../csrc/runtime/pjrt_computation_client.h | 10 ++++ ...tributed_runtime.cc => xla_coordinator.cc} | 39 ++++++++++---- torch_xla/csrc/runtime/xla_coordinator.h | 53 +++++++++++++++++++ 10 files changed, 230 insertions(+), 60 deletions(-) delete mode 100644 torch_xla/csrc/runtime/distributed_runtime.h rename torch_xla/csrc/runtime/{distributed_runtime.cc => xla_coordinator.cc} (53%) create mode 100644 torch_xla/csrc/runtime/xla_coordinator.h diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 76ea6b71672d..29ed825d015b 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,14 +1,17 @@ import functools import os +import signal import sys import tempfile -import unittest import test_xla_sharding_base import threading +import time +import unittest import torch import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr import torch_xla.experimental.xla_sharding as xs @@ -339,7 +342,12 @@ def setUp(self): super().setUp() # Initialize the a minimal process group dist.init_process_group( - backend='gloo', init_method='tcp://127.1:8932', world_size=1, 
rank=0) + backend='gloo', + init_method='tcp://localhost:8932', + world_size=1, + rank=0) + torch_xla._XLAC._ensure_xla_coordinator_initialized( + global_rank=0, world_size=1, master_addr="localhost") def tearDown(self): super().tearDown() @@ -486,6 +494,34 @@ def test_master_ip_discovery(self, patched_get_worker_ips): patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2'] self.assertTrue(xr.get_master_ip(), '10.0.0.1') + def test_preemption_sync_manager(self): + try: + torch_xla._XLAC._activate_preemption_sync_manager() + sync_point_reached = torch_xla._XLAC._sync_point_reached + + # No sync point for the first several steps + sigterm_step = 10 + for step in range(sigterm_step): + self.assertFalse(sync_point_reached(step)) + + # Send a SIGTERM to the current process to trigger a sync point + os.kill(os.getpid(), signal.SIGTERM) + + # Allow the signal to be processed. The PreemptionSyncManager must receive + # the SIGTERM, which happens asynchronously, and the state must be + # propagated through the distributed runtime. Eventually, + # sync_point_reached will return True. + success = False + for attempt in range(10): + success = sync_point_reached(sigterm_step + attempt) + if success: + break + time.sleep(1) + self.assertTrue(success, "Sync point was never reached after SIGTERM") + finally: + # Scope the PreemptionSyncManager to the lifespan of the test. 
+ torch_xla._XLAC._deactivate_preemption_sync_manager() + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 352db2d34fbd..d4eed1ec83f2 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -259,6 +259,7 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:sys_util", "//torch_xla/csrc/runtime:thread_pool", "//torch_xla/csrc/runtime:util", + "//torch_xla/csrc/runtime:xla_coordinator", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1e6bb020fe50..3e2b32cc2d8f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -47,6 +47,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_impl.h" @@ -1876,6 +1877,45 @@ void InitXlaModuleBindings(py::module m) { xla::HloModule::CreateFromProto(module_proto, config).value()); return module->ToString(); }); + // Initialize the XlaCoordinator in the runtime if not already initialized. + m.def("_ensure_xla_coordinator_initialized", + [](int global_rank, int world_size, std::string master_addr, + std::string master_port) { + auto comp_client = runtime::GetComputationClient(); + if (!comp_client->CoordinatorInitialized()) { + runtime::GetComputationClient()->InitializeCoordinator( + global_rank, world_size, master_addr, master_port); + } + }, + py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"), + py::arg("master_port") = + runtime::XlaCoordinator::kDefaultCoordinatorPort); + // Create a PreemptionSyncManager for the XlaCoordinator. 
The + // PreemptionSyncManager will register a SIGTERM handler as a side effect. + m.def("_activate_preemption_sync_manager", []() { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.ActivatePreemptionSyncManager(); + }); + // Deactivate the PreemptionSyncManager in the XlaCoordinator if one is active + m.def("_deactivate_preemption_sync_manager", []() { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.DeactivatePreemptionSyncManager(); + }); + // Check whether a sync point has been reached. This method requires that the + // distributed runtime be initialized and a PreemptionSyncManager activated. + m.def("_sync_point_reached", [](int step) { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + return coordinator.ReachedSyncPoint(step); + }); m.def("_is_placecholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); return xtensor->CurrentDataHandle() && diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index db054e32289b..da12492d2cdc 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -63,6 +63,7 @@ cc_library( ":sys_util", ":types", ":util", + ":xla_coordinator", "//torch_xla/csrc:device", "@tsl//tsl/platform:stacktrace_handler", "@xla//xla:literal_util", @@ -88,12 +89,12 @@ cc_library( deps = [ ":computation_client", ":debug_macros", - ":distributed_runtime", ":env_vars", ":multi_wait", ":stablehlo_helper", ":tf_logging", ":thread_pool", + ":xla_coordinator", "@xla//xla:literal", "@xla//xla:shape_util", 
"@xla//xla/client:xla_computation", @@ -165,13 +166,14 @@ cc_library( ) cc_library( - name = "distributed_runtime", - srcs = ["distributed_runtime.cc"], - hdrs = ["distributed_runtime.h"], + name = "xla_coordinator", + srcs = ["xla_coordinator.cc"], + hdrs = ["xla_coordinator.h"], deps = [ ":debug_macros", ":sys_util", "@xla//xla/pjrt/distributed", + "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", ], ) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index db4bac21916a..145a6d0aa091 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -30,6 +30,12 @@ namespace torch_xla { namespace runtime { +// Forward declare XlaCoordinator to avoid logging macro redefinition from the +// transitively included PJRT header. +// TODO(jonbolin): We need a way to ensure the right macros are included +// regardless of the import order. +class XlaCoordinator; + // Somehow the compiler doesn't allow type that has default member being // used as a default parameter in a method defined in the same scope. // Therefore, ClientExecuteOptions is defined here instead of within @@ -348,6 +354,17 @@ class ComputationClient { // the local devices will be waited for. virtual void WaitDeviceOps(const std::vector& devices) = 0; + // Check whether the XlaCoordinator has been initialized. + virtual bool CoordinatorInitialized() const = 0; + + // Initialize the XlaCoordinator for the runtime. + virtual void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) = 0; + + // Return the XlaCoordinator for the runtime. + virtual XlaCoordinator& GetCoordinator() = 0; + // Utility API around the vector based Compile() API to compile a single // computation. 
ComputationPtr Compile(xla::XlaComputation computation, diff --git a/torch_xla/csrc/runtime/distributed_runtime.h b/torch_xla/csrc/runtime/distributed_runtime.h deleted file mode 100644 index f26ef3d008c2..000000000000 --- a/torch_xla/csrc/runtime/distributed_runtime.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef XLA_CLIENT_DISTRIBUTED_RUNTIME_H_ -#define XLA_CLIENT_DISTRIBUTED_RUNTIME_H_ - -#include - -#include "xla/pjrt/distributed/distributed.h" - -namespace torch_xla { -namespace runtime { - -class DistributedRuntime { - public: - static const std::string default_coordinator_port; - static DistributedRuntime& getInstance(int global_rank, - std::string master_addr, - std::string port) { - static DistributedRuntime dist_runtime_instance(global_rank, master_addr, - port); - return dist_runtime_instance; - } - ~DistributedRuntime(); - DistributedRuntime(DistributedRuntime const&) = delete; - void operator=(DistributedRuntime const&) = delete; - - std::shared_ptr GetClient(); - - private: - DistributedRuntime(int global_rank, std::string master_addr, - std::string port); - - std::unique_ptr dist_runtime_service_; - std::shared_ptr dist_runtime_client_; -}; - -} // namespace runtime -} // namespace torch_xla - -#endif // XLA_CLIENT_DISTRIBUTED_RUNTIME_H_ diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 2ae0768856b7..c003f4f97060 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -9,12 +9,12 @@ #include "pjrt_computation_client.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/distributed_runtime.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/thread_pool.h" 
+#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" @@ -109,13 +109,18 @@ PjRtComputationClient::PjRtComputationClient() { bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); std::string master_addr = runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port); + "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + // Use the XlaCoordinator as the distributed key-value store. + coordinator_ = std::make_unique( + global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = - DistributedRuntime::getInstance(global_process_rank, master_addr, port) - .GetClient(); + coordinator_->GetClient(); auto allowed_devices = std::make_optional>(std::set{local_process_rank}); xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; @@ -132,8 +137,6 @@ PjRtComputationClient::PjRtComputationClient() { return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); }; } - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; client_ = std::move(xla::GetStreamExecutorGpuClient( @@ -185,6 +188,33 @@ PjRtComputationClient::PjRtComputationClient() { device_locks_.emplace(spmd_device_str, std::make_unique()); } +PjRtComputationClient::~PjRtComputationClient() { + // In the 
GPU case, the PjRtClient depends on the DistributedRuntimeClient + // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. + client_ = nullptr; + coordinator_ = nullptr; +} + +bool PjRtComputationClient::CoordinatorInitialized() const { + return coordinator_ != nullptr; +} + +void PjRtComputationClient::InitializeCoordinator(int global_rank, + int world_size, + std::string master_addr, + std::string port) { + XLA_CHECK(coordinator_ == nullptr) + << "Can only initialize the XlaCoordinator once."; + coordinator_ = std::make_unique(global_rank, world_size, + master_addr, port); +} + +XlaCoordinator& PjRtComputationClient::GetCoordinator() { + XLA_CHECK(coordinator_ != nullptr) + << "XlaCoordinator has not been initialized"; + return *coordinator_; +} + void PjRtComputationClient::PjRtData::Assign( const torch::lazy::BackendData& data) { const PjRtData& pjrt_data = dynamic_cast(data); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index f4fc73bb79e5..faebd4892b87 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -23,6 +23,7 @@ namespace runtime { class PjRtComputationClient : public ComputationClient { public: PjRtComputationClient(); + ~PjRtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; @@ -89,6 +90,14 @@ class PjRtComputationClient : public ComputationClient { std::map GetMetrics() const override; + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; + + XlaCoordinator& GetCoordinator() override; + + bool CoordinatorInitialized() const override; + // NOT IMPLEMENTED MemoryInfo GetMemoryInfo(const std::string& device) override { @@ -97,6 +106,7 @@ class PjRtComputationClient : public ComputationClient { private: std::shared_ptr client_; + std::unique_ptr coordinator_; // global_ordinals_ tracks a map 
from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; diff --git a/torch_xla/csrc/runtime/distributed_runtime.cc b/torch_xla/csrc/runtime/xla_coordinator.cc similarity index 53% rename from torch_xla/csrc/runtime/distributed_runtime.cc rename to torch_xla/csrc/runtime/xla_coordinator.cc index dc3dbaf4eb49..72855d8681ea 100644 --- a/torch_xla/csrc/runtime/distributed_runtime.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -1,21 +1,18 @@ -#include "torch_xla/csrc/runtime/distributed_runtime.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/pjrt/distributed/distributed.h" namespace torch_xla { namespace runtime { -const std::string DistributedRuntime::default_coordinator_port = "8547"; - -DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr, - std::string port) { +XlaCoordinator::XlaCoordinator(int global_rank, int world_size, + std::string master_addr, std::string port) { std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":"); if (global_rank == 0) { - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); xla::CoordinationServiceImpl::Options service_options; - service_options.num_nodes = global_world_size; + service_options.num_nodes = world_size; xla::StatusOr> dist_runtime_service = xla::GetDistributedRuntimeService( dist_service_addr, service_options); @@ -32,7 +29,8 @@ DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr, << "Failed to initialize distributed runtime client"; } -DistributedRuntime::~DistributedRuntime() { +XlaCoordinator::~XlaCoordinator() { + preemption_sync_manager_ = nullptr; if (dist_runtime_client_ != nullptr) { XLA_CHECK(dist_runtime_client_->Shutdown().ok()) << "Failed to shut down the distributed runtime client."; @@ 
-44,11 +42,32 @@ DistributedRuntime::~DistributedRuntime() { } } -std::shared_ptr DistributedRuntime::GetClient() { +std::shared_ptr XlaCoordinator::GetClient() { XLA_CHECK(dist_runtime_client_ != nullptr) << "distributed runtime client is null."; return dist_runtime_client_; } +void XlaCoordinator::ActivatePreemptionSyncManager() { + if (preemption_sync_manager_ == nullptr) { + preemption_sync_manager_ = std::move(tsl::CreatePreemptionSyncManager()); + auto client = dist_runtime_client_->GetCoordinationServiceAgent(); + XLA_CHECK(client.ok()) << "Failed to retrieve the CoodinationServiceAgent"; + auto status = preemption_sync_manager_->Initialize(client.value()); + XLA_CHECK(status.ok()) << "Failed to initialize the PreemptionSyncManager"; + } +} + +void XlaCoordinator::DeactivatePreemptionSyncManager() { + preemption_sync_manager_ = nullptr; +} + +bool XlaCoordinator::ReachedSyncPoint(int step) { + XLA_CHECK(preemption_sync_manager_ != nullptr) + << "A PreemptionSyncManager has not been registered with the " + "XlaCoordinator."; + return preemption_sync_manager_->ReachedSyncPoint(step); +} + } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h new file mode 100644 index 000000000000..ae85c79a9416 --- /dev/null +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -0,0 +1,53 @@ +#ifndef PTXLA_RUNTIME_COORDINATOR_H_ +#define PTXLA_RUNTIME_COORDINATOR_H_ + +#include + +#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/pjrt/distributed/distributed.h" + +namespace torch_xla { +namespace runtime { + +// XlaCoordinator serves as the point of entry for all operations which +// require the XLA distributed runtime, such as preemption coordination. 
+class XlaCoordinator { + public: + static inline const std::string kDefaultCoordinatorPort = "8547"; + + XlaCoordinator(int global_rank, int world_size, std::string master_addr, + std::string port); + + ~XlaCoordinator(); + + // Retrieve the DistributedRuntimeClient. + std::shared_ptr GetClient(); + + // Register a PreemptionSyncManager for the distributed runtime if none is + // active. The PreemptionSyncManager will register a SIGTERM handler, and + // when any host has received a preemption notice, all hosts are made aware + // through the ReachedSyncPoint API. See the documentation of + // tsl::PreemptionSyncManager for the full semantics: + // https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34 + void ActivatePreemptionSyncManager(); + + // If the PreemptionSyncManager is active, this will deactivate it and + // destroy the current instance. + void DeactivatePreemptionSyncManager(); + + // A pass-through API to PreemptionSyncManager::ReachedSyncPoint. + // The PreemptionSyncManager must be activated within the XlaCoordinator. + // Returns true when the input step has been identified as a sync point, and + // false otherwise. + bool ReachedSyncPoint(int step); + + private: + std::unique_ptr dist_runtime_service_; + std::shared_ptr dist_runtime_client_; + std::unique_ptr preemption_sync_manager_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // PTXLA_RUNTIME_COORDINATOR_H_