Skip to content

Commit

Permalink
Clean up logging macro issue handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 31, 2023
1 parent 278d2ef commit d880511
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 15 deletions.
7 changes: 5 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,12 @@ def setUp(self):
super().setUp()
# Initialize the a minimal process group
dist.init_process_group(
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
backend='gloo',
init_method='tcp://localhost:8932',
world_size=1,
rank=0)
torch_xla._XLAC._ensure_xla_coordinator_initialized(
global_rank=0, world_size=1, master_addr="127.1")
global_rank=0, world_size=1, master_addr="localhost")

def tearDown(self):
super().tearDown()
Expand Down
5 changes: 4 additions & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
namespace torch_xla {
namespace runtime {

// Forward declaration
// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
// regardless of the import order.
class XlaCoordinator;

// Somehow the compiler doesn't allow type that has default member being
Expand Down
10 changes: 8 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
//#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -119,7 +119,8 @@ PjRtComputationClient::PjRtComputationClient() {
// Use the XlaCoordinator as the distributed key-value store.
coordinator_ = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
auto distributed_client = coordinator_->GetClient();
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator_->GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
Expand Down Expand Up @@ -187,6 +188,11 @@ PjRtComputationClient::PjRtComputationClient() {
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());
}

PjRtComputationClient::~PjRtComputationClient() {
coordinator_ = nullptr;
client_ = nullptr;
}

bool PjRtComputationClient::CoordinatorInitialized() const {
return coordinator_ != nullptr;
}
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
Expand All @@ -24,6 +23,7 @@ namespace runtime {
class PjRtComputationClient : public ComputationClient {
public:
PjRtComputationClient();
~PjRtComputationClient();

DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override;

Expand Down Expand Up @@ -90,9 +90,9 @@ class PjRtComputationClient : public ComputationClient {

std::map<std::string, Metric> GetMetrics() const override;

void InitializeCoordinator(
int global_rank, int world_size, std::string master_addr,
std::string port = XlaCoordinator::kDefaultCoordinatorPort) override;
void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) override;

XlaCoordinator& GetCoordinator() override;

Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/xla_coordinator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size,
}

XlaCoordinator::~XlaCoordinator() {
preemption_sync_manager_ = nullptr;
if (dist_runtime_client_ != nullptr) {
XLA_CHECK(dist_runtime_client_->Shutdown().ok())
<< "Failed to shut down the distributed runtime client.";
Expand Down
7 changes: 1 addition & 6 deletions torch_xla/csrc/runtime/xla_coordinator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
#include <memory>

#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"

// Forward declaration
namespace xla {
class DistributedRuntimeClient;
class DistributedRuntimeService;
} // namespace xla
#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {
Expand Down

0 comments on commit d880511

Please sign in to comment.