Skip to content

Commit

Permalink
Clean up logging macro issue handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 27, 2023
1 parent 10dc8db commit f20d522
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 15 deletions.
4 changes: 2 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,9 @@ def setUp(self):
super().setUp()
# Initialize a minimal process group
dist.init_process_group(
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
backend='gloo', init_method='tcp://localhost:8932', world_size=1, rank=0)
torch_xla._XLAC._ensure_xla_coordinator_initialized(
global_rank=0, world_size=1, master_addr="127.1")
global_rank=0, world_size=1, master_addr="localhost")

def tearDown(self):
super().tearDown()
Expand Down
5 changes: 4 additions & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
namespace torch_xla {
namespace runtime {

// Forward declaration
// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
// regardless of the import order.
class XlaCoordinator;

// Somehow the compiler doesn't allow type that has default member being
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
//#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -119,7 +119,7 @@ PjRtComputationClient::PjRtComputationClient() {
// Use the XlaCoordinator as the distributed key-value store.
coordinator_ = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
auto distributed_client = coordinator_->GetClient();
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client = coordinator_->GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
Expand Down
7 changes: 3 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
Expand Down Expand Up @@ -92,9 +91,9 @@ class PjRtComputationClient : public ComputationClient {

std::map<std::string, Metric> GetMetrics() const override;

void InitializeCoordinator(
int global_rank, int world_size, std::string master_addr,
std::string port = XlaCoordinator::kDefaultCoordinatorPort) override;
void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) override;

XlaCoordinator& GetCoordinator() override;

Expand Down
7 changes: 1 addition & 6 deletions torch_xla/csrc/runtime/xla_coordinator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
#include <memory>

#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"

// Forward declaration
namespace xla {
class DistributedRuntimeClient;
class DistributedRuntimeService;
} // namespace xla
#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {
Expand Down

0 comments on commit f20d522

Please sign in to comment.