Skip to content

Commit

Permalink
Refactor to be owned by ComputationClient
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 27, 2023
1 parent 54961b7 commit 10dc8db
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 127 deletions.
8 changes: 4 additions & 4 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ def setUp(self):
# Initialize the a minimal process group
dist.init_process_group(
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
torch_xla._XLAC._ensure_xla_coordinator_initialized(
global_rank=0, world_size=1, master_addr="127.1")

def tearDown(self):
super().tearDown()
Expand Down Expand Up @@ -482,8 +484,6 @@ def test_manager_async_step_tracking(self, tmpdir):

def test_preemption_sync_manager(self):
try:
torch_xla._XLAC._ensure_dist_runtime_initialized(
global_rank=0, world_size=1, master_addr="127.1")
torch_xla._XLAC._activate_preemption_sync_manager()
sync_point_reached = torch_xla._XLAC._sync_point_reached

Expand All @@ -507,8 +507,8 @@ def test_preemption_sync_manager(self):
time.sleep(1)
self.assertTrue(success, "Sync point was never reached after SIGTERM")
finally:
# Scope the distributed runtime to the lifespan of the test.
torch_xla._XLAC._ensure_dist_runtime_shutdown()
# Scope the PreemptionSyncManager to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ ptxla_cc_library(
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc/runtime:util",
"//torch_xla/csrc/runtime:xla_coordinator",
"//torch_xla/csrc/runtime:xla_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
Expand Down
48 changes: 27 additions & 21 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/metrics_analysis.h"
#include "torch_xla/csrc/runtime/metrics_reader.h"
Expand All @@ -48,6 +47,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_impl.h"
Expand Down Expand Up @@ -1878,38 +1878,44 @@ void InitXlaModuleBindings(py::module m) {
xla::HloModule::CreateFromProto(module_proto, config).value());
return module->ToString();
});
// Initialize a distributed runtime if one has not already been created.
m.def("_ensure_dist_runtime_initialized",
// Initialize the XlaCoordinator in the runtime if not already initialized.
m.def("_ensure_xla_coordinator_initialized",
[](int global_rank, int world_size, std::string master_addr,
std::string master_port) {
if (!runtime::DistributedRuntime::IsInitialized()) {
runtime::DistributedRuntime::Initialize(global_rank, world_size,
master_addr, master_port);
auto comp_client = runtime::GetComputationClient();
if (!comp_client->CoordinatorInitialized()) {
runtime::GetComputationClient()->InitializeCoordinator(
global_rank, world_size, master_addr, master_port);
}
},
py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"),
py::arg("master_port") =
runtime::DistributedRuntime::kDefaultCoordinatorPort);
// Shutdown the distributed runtime if it's active.
m.def("_ensure_dist_runtime_shutdown", []() {
if (runtime::DistributedRuntime::IsInitialized()) {
runtime::DistributedRuntime::Shutdown();
}
});
// Create a PreemptionSyncManager for the DistributedRuntime. The
runtime::XlaCoordinator::kDefaultCoordinatorPort);
// Create a PreemptionSyncManager for the XlaCoordinator. The
// PreemptionSyncManager will register a SIGTERM handler as a side effect.
m.def("_activate_preemption_sync_manager", []() {
XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
<< "DistributedRuntime must be initialized to register "
"PreemptionSyncManager";
runtime::DistributedRuntime::Get().ActivatePreemptionSyncManager();
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
coordinator.ActivatePreemptionSyncManager();
});
// Deactivate the PreemptionSyncManager in the XlaCoordinator if one is active
m.def("_deactivate_preemption_sync_manager", []() {
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
coordinator.DeactivatePreemptionSyncManager();
});
// Check whether a sync point has been reached. This method requires that the
// distributed runtime be initialized and a PreemptionSyncManager activated.
m.def("_sync_point_reached", [](int step) {
XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
<< "DistributedRuntime must be initialized";
return runtime::DistributedRuntime::Get().ReachedSyncPoint(step);
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
return coordinator.ReachedSyncPoint(step);
});
m.def("_is_placecholder", [](at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
Expand Down
9 changes: 5 additions & 4 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_library(
":sys_util",
":types",
":util",
":xla_coordinator",
"//torch_xla/csrc:device",
"@tsl//tsl/platform:stacktrace_handler",
"@xla//xla:literal_util",
Expand All @@ -88,12 +89,12 @@ cc_library(
deps = [
":computation_client",
":debug_macros",
":distributed_runtime",
":env_vars",
":multi_wait",
":stablehlo_helper",
":tf_logging",
":thread_pool",
":xla_coordinator",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
Expand Down Expand Up @@ -165,9 +166,9 @@ cc_library(
)

cc_library(
name = "distributed_runtime",
srcs = ["distributed_runtime.cc"],
hdrs = ["distributed_runtime.h"],
name = "xla_coordinator",
srcs = ["xla_coordinator.cc"],
hdrs = ["xla_coordinator.h"],
deps = [
":debug_macros",
":sys_util",
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
namespace torch_xla {
namespace runtime {

// Forward declaration
class XlaCoordinator;

// Somehow the compiler doesn't allow type that has default member being
// used as a default parameter in a method defined in the same scope.
// Therefore, ClientExecuteOptions is defined here instead of within
Expand Down Expand Up @@ -350,6 +353,17 @@ class ComputationClient {
// the local devices will be waited for.
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;

// Check whether the XlaCoordinator has been initialized.
virtual bool CoordinatorInitialized() const = 0;

// Initialize the XlaCoordinator for the runtime.
virtual void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) = 0;

// Return the XlaCoordinator for the runtime.
virtual XlaCoordinator& GetCoordinator() = 0;

// Utility API around the vector based Compile() API to compile a single
// computation.
ComputationPtr Compile(xla::XlaComputation computation,
Expand Down
83 changes: 0 additions & 83 deletions torch_xla/csrc/runtime/distributed_runtime.h

This file was deleted.

32 changes: 26 additions & 6 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
//#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
Expand Down Expand Up @@ -114,12 +114,12 @@ PjRtComputationClient::PjRtComputationClient() {
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", DistributedRuntime::kDefaultCoordinatorPort);
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

// Use the DistributedRuntime as the distributed key-value store.
DistributedRuntime::Initialize(global_process_rank, global_world_size,
master_addr, port);
auto distributed_client = DistributedRuntime::Get().GetClient();
// Use the XlaCoordinator as the distributed key-value store.
coordinator_ = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
auto distributed_client = coordinator_->GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
Expand Down Expand Up @@ -187,6 +187,26 @@ PjRtComputationClient::PjRtComputationClient() {
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());
}

bool PjRtComputationClient::CoordinatorInitialized() const {
return coordinator_ != nullptr;
}

void PjRtComputationClient::InitializeCoordinator(int global_rank,
int world_size,
std::string master_addr,
std::string port) {
XLA_CHECK(coordinator_ == nullptr)
<< "Can only initialize the XlaCoordinator once.";
coordinator_ = std::make_unique<XlaCoordinator>(global_rank, world_size,
master_addr, port);
}

XlaCoordinator& PjRtComputationClient::GetCoordinator() {
XLA_CHECK(coordinator_ != nullptr)
<< "XlaCoordinator has not been initialized";
return *coordinator_;
}

void PjRtComputationClient::PjRtData::Assign(
const torch::lazy::BackendData& data) {
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(data);
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
#include "xla/pjrt/pjrt_client.h"
Expand Down Expand Up @@ -91,6 +92,14 @@ class PjRtComputationClient : public ComputationClient {

std::map<std::string, Metric> GetMetrics() const override;

void InitializeCoordinator(
int global_rank, int world_size, std::string master_addr,
std::string port = XlaCoordinator::kDefaultCoordinatorPort) override;

XlaCoordinator& GetCoordinator() override;

bool CoordinatorInitialized() const override;

// NOT IMPLEMENTED

MemoryInfo GetMemoryInfo(const std::string& device) override {
Expand All @@ -99,6 +108,7 @@ class PjRtComputationClient : public ComputationClient {

private:
std::shared_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
// global_ordinals_ tracks a map from PjRtDeviceId to the device's
// dense global ordinal.
std::unordered_map<int, int> global_ordinals_;
Expand Down
Loading

0 comments on commit 10dc8db

Please sign in to comment.