From f20d522730e914318cb4118f067e197346831248 Mon Sep 17 00:00:00 2001
From: Jon Bolin
Date: Fri, 27 Oct 2023 20:26:27 +0000
Subject: [PATCH] Clean up logging macro issue handling

---
 test/spmd/test_xla_distributed_checkpoint.py      | 4 ++--
 torch_xla/csrc/runtime/computation_client.h       | 5 ++++-
 torch_xla/csrc/runtime/pjrt_computation_client.cc | 4 ++--
 torch_xla/csrc/runtime/pjrt_computation_client.h  | 7 +++----
 torch_xla/csrc/runtime/xla_coordinator.h          | 7 +------
 5 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 9104de97b75f..bb97fc041098 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -342,9 +342,9 @@ def setUp(self):
     super().setUp()
     # Initialize the a minimal process group
     dist.init_process_group(
-        backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
+        backend='gloo', init_method='tcp://localhost:8932', world_size=1, rank=0)
     torch_xla._XLAC._ensure_xla_coordinator_initialized(
-        global_rank=0, world_size=1, master_addr="127.1")
+        global_rank=0, world_size=1, master_addr="localhost")
 
   def tearDown(self):
     super().tearDown()
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 1d09efedc2e4..38cb1deaf7ca 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -30,7 +30,10 @@
 namespace torch_xla {
 namespace runtime {
 
-// Forward declaration
+// Forward declare XlaCoordinator to avoid logging macro redefinition from the
+// transitively included PJRT header.
+// TODO(jonbolin): We need a way to ensure the right macros are included
+// regardless of the import order.
 class XlaCoordinator;
 
 // Somehow the compiler doesn't allow type that has default member being
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 5f49fd2e3570..c07265317c0c 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -9,12 +9,12 @@
 #include "pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
-//#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "tsl/profiler/lib/traceme.h"
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
@@ -119,7 +119,7 @@ PjRtComputationClient::PjRtComputationClient() {
     // Use the XlaCoordinator as the distributed key-value store.
     coordinator_ = std::make_unique<XlaCoordinator>(
         global_process_rank, global_world_size, master_addr, port);
-    auto distributed_client = coordinator_->GetClient();
+    std::shared_ptr<xla::DistributedRuntimeClient> distributed_client = coordinator_->GetClient();
     auto allowed_devices =
         std::make_optional<std::set<int>>(std::set{local_process_rank});
     xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index aee98fc7d3a1..4ea3d96acefd 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -11,7 +11,6 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/util.h"
-#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "xla/client/xla_computation.h"
 #include "xla/literal.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -92,9 +91,9 @@
 
   std::map<std::string, Metric> GetMetrics() const override;
 
-  void InitializeCoordinator(
-      int global_rank, int world_size, std::string master_addr,
-      std::string port = XlaCoordinator::kDefaultCoordinatorPort) override;
+  void InitializeCoordinator(int global_rank, int world_size,
+                             std::string master_addr,
+                             std::string port) override;
 
   XlaCoordinator& GetCoordinator() override;
 
diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h
index 88cc3e752dd9..ae85c79a9416 100644
--- a/torch_xla/csrc/runtime/xla_coordinator.h
+++ b/torch_xla/csrc/runtime/xla_coordinator.h
@@ -4,12 +4,7 @@
 #include <memory>
 
 #include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"
-
-// Forward declaration
-namespace xla {
-class DistributedRuntimeClient;
-class DistributedRuntimeService;
-}  // namespace xla
+#include "xla/pjrt/distributed/distributed.h"
 
 namespace torch_xla {
 namespace runtime {
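
Note for reviewers (not part of the patch): the computation_client.h change relies on the standard C++ forward-declaration idiom, so it may help to record why it is safe here. The sketch below uses hypothetical names (Client, client.h, client.cc) rather than the real classes; it illustrates that a std::unique_ptr member tolerates an incomplete type as long as the destructor is defined in the one .cc file that includes the full xla_coordinator.h, which is also the one place where the logging-macro include order gets resolved.

// client.h -- hypothetical header mirroring the computation_client.h pattern.
#include <memory>

class XlaCoordinator;  // forward declaration; no PJRT headers pulled in here

class Client {
 public:
  Client();
  // Declared here but defined in client.cc: the implicitly generated
  // destructor would need XlaCoordinator to be complete to delete coordinator_.
  ~Client();

  // Returning by reference requires only the forward declaration.
  XlaCoordinator& GetCoordinator();

 private:
  std::unique_ptr<XlaCoordinator> coordinator_;  // ok with an incomplete type
};

// client.cc -- the single translation unit that sees the full definition,
// so the macro include order is controlled in one place:
//
//   #include "client.h"
//   #include "torch_xla/csrc/runtime/xla_coordinator.h"
//
//   Client::~Client() = default;  // XlaCoordinator is complete here
//   XlaCoordinator& Client::GetCoordinator() { return *coordinator_; }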