Skip to content

Commit

Permalink
Support PreemptionSyncManager in XlaCoordinator (#5733)
Browse files Browse the repository at this point in the history
* Support PreemptionSyncManager in DistributedRuntime

* Refactor to be owned by ComputationClient

* Clean up logging macro issue handling
  • Loading branch information
jonb377 authored and ManfeiBai committed Nov 29, 2023
1 parent 885f15f commit 244efa7
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 60 deletions.
40 changes: 38 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import functools
import os
import signal
import sys
import tempfile
import unittest
import test_xla_sharding_base
import threading
import time
import unittest

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
Expand Down Expand Up @@ -339,7 +342,12 @@ def setUp(self):
super().setUp()
# Initialize the a minimal process group
dist.init_process_group(
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
backend='gloo',
init_method='tcp://localhost:8932',
world_size=1,
rank=0)
torch_xla._XLAC._ensure_xla_coordinator_initialized(
global_rank=0, world_size=1, master_addr="localhost")

def tearDown(self):
super().tearDown()
Expand Down Expand Up @@ -486,6 +494,34 @@ def test_master_ip_discovery(self, patched_get_worker_ips):
patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
self.assertTrue(xr.get_master_ip(), '10.0.0.1')

def test_preemption_sync_manager(self):
try:
torch_xla._XLAC._activate_preemption_sync_manager()
sync_point_reached = torch_xla._XLAC._sync_point_reached

# No sync point for the first several steps
sigterm_step = 10
for step in range(sigterm_step):
self.assertFalse(sync_point_reached(step))

# Send a SIGTERM to the current process to trigger a sync point
os.kill(os.getpid(), signal.SIGTERM)

# Allow the signal to be processed. The PreemptionSyncManager must receive
# the SIGTERM, which happens asynchronously, and the state must be
# propagated through the distributed runtime. Eventually,
# sync_point_reached will return True.
success = False
for attempt in range(10):
success = sync_point_reached(sigterm_step + attempt)
if success:
break
time.sleep(1)
self.assertTrue(success, "Sync point was never reached after SIGTERM")
finally:
# Scope the PreemptionSyncManager to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()


if __name__ == '__main__':
test = unittest.main()
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ ptxla_cc_library(
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:thread_pool",
"//torch_xla/csrc/runtime:util",
"//torch_xla/csrc/runtime:xla_coordinator",
"//torch_xla/csrc/runtime:xla_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
Expand Down
40 changes: 40 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_impl.h"
Expand Down Expand Up @@ -1876,6 +1877,45 @@ void InitXlaModuleBindings(py::module m) {
xla::HloModule::CreateFromProto(module_proto, config).value());
return module->ToString();
});
// Initialize the XlaCoordinator in the runtime if not already initialized.
m.def("_ensure_xla_coordinator_initialized",
[](int global_rank, int world_size, std::string master_addr,
std::string master_port) {
auto comp_client = runtime::GetComputationClient();
if (!comp_client->CoordinatorInitialized()) {
runtime::GetComputationClient()->InitializeCoordinator(
global_rank, world_size, master_addr, master_port);
}
},
py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"),
py::arg("master_port") =
runtime::XlaCoordinator::kDefaultCoordinatorPort);
// Create a PreemptionSyncManager for the XlaCoordinator. The
// PreemptionSyncManager will register a SIGTERM handler as a side effect.
m.def("_activate_preemption_sync_manager", []() {
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
coordinator.ActivatePreemptionSyncManager();
});
// Deactivate the PreemptionSyncManager in the XlaCoordinator if one is active
m.def("_deactivate_preemption_sync_manager", []() {
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
coordinator.DeactivatePreemptionSyncManager();
});
// Check whether a sync point has been reached. This method requires that the
// distributed runtime be initialized and a PreemptionSyncManager activated.
m.def("_sync_point_reached", [](int step) {
auto comp_client = runtime::GetComputationClient();
XLA_CHECK(comp_client->CoordinatorInitialized())
<< "Coordinator must be initialized";
auto& coordinator = comp_client->GetCoordinator();
return coordinator.ReachedSyncPoint(step);
});
m.def("_is_placecholder", [](at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->CurrentDataHandle() &&
Expand Down
10 changes: 6 additions & 4 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_library(
":sys_util",
":types",
":util",
":xla_coordinator",
"//torch_xla/csrc:device",
"@tsl//tsl/platform:stacktrace_handler",
"@xla//xla:literal_util",
Expand All @@ -88,12 +89,12 @@ cc_library(
deps = [
":computation_client",
":debug_macros",
":distributed_runtime",
":env_vars",
":multi_wait",
":stablehlo_helper",
":tf_logging",
":thread_pool",
":xla_coordinator",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
Expand Down Expand Up @@ -165,13 +166,14 @@ cc_library(
)

cc_library(
name = "distributed_runtime",
srcs = ["distributed_runtime.cc"],
hdrs = ["distributed_runtime.h"],
name = "xla_coordinator",
srcs = ["xla_coordinator.cc"],
hdrs = ["xla_coordinator.h"],
deps = [
":debug_macros",
":sys_util",
"@xla//xla/pjrt/distributed",
"@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager",
],
)

Expand Down
17 changes: 17 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
namespace torch_xla {
namespace runtime {

// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
// regardless of the import order.
class XlaCoordinator;

// Somehow the compiler doesn't allow type that has default member being
// used as a default parameter in a method defined in the same scope.
// Therefore, ClientExecuteOptions is defined here instead of within
Expand Down Expand Up @@ -348,6 +354,17 @@ class ComputationClient {
// the local devices will be waited for.
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;

// Check whether the XlaCoordinator has been initialized.
virtual bool CoordinatorInitialized() const = 0;

// Initialize the XlaCoordinator for the runtime.
virtual void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) = 0;

// Return the XlaCoordinator for the runtime.
virtual XlaCoordinator& GetCoordinator() = 0;

// Utility API around the vector based Compile() API to compile a single
// computation.
ComputationPtr Compile(xla::XlaComputation computation,
Expand Down
38 changes: 0 additions & 38 deletions torch_xla/csrc/runtime/distributed_runtime.h

This file was deleted.

42 changes: 36 additions & 6 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -109,13 +109,18 @@ PjRtComputationClient::PjRtComputationClient() {
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);
int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port);
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

// Use the XlaCoordinator as the distributed key-value store.
coordinator_ = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
DistributedRuntime::getInstance(global_process_rank, master_addr, port)
.GetClient();
coordinator_->GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
Expand All @@ -132,8 +137,6 @@ PjRtComputationClient::PjRtComputationClient() {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
Expand Down Expand Up @@ -185,6 +188,33 @@ PjRtComputationClient::PjRtComputationClient() {
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());
}

PjRtComputationClient::~PjRtComputationClient() {
// In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
// tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
client_ = nullptr;
coordinator_ = nullptr;
}

bool PjRtComputationClient::CoordinatorInitialized() const {
return coordinator_ != nullptr;
}

void PjRtComputationClient::InitializeCoordinator(int global_rank,
int world_size,
std::string master_addr,
std::string port) {
XLA_CHECK(coordinator_ == nullptr)
<< "Can only initialize the XlaCoordinator once.";
coordinator_ = std::make_unique<XlaCoordinator>(global_rank, world_size,
master_addr, port);
}

XlaCoordinator& PjRtComputationClient::GetCoordinator() {
XLA_CHECK(coordinator_ != nullptr)
<< "XlaCoordinator has not been initialized";
return *coordinator_;
}

void PjRtComputationClient::PjRtData::Assign(
const torch::lazy::BackendData& data) {
const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(data);
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace runtime {
class PjRtComputationClient : public ComputationClient {
public:
PjRtComputationClient();
~PjRtComputationClient();

DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override;

Expand Down Expand Up @@ -89,6 +90,14 @@ class PjRtComputationClient : public ComputationClient {

std::map<std::string, Metric> GetMetrics() const override;

void InitializeCoordinator(int global_rank, int world_size,
std::string master_addr,
std::string port) override;

XlaCoordinator& GetCoordinator() override;

bool CoordinatorInitialized() const override;

// NOT IMPLEMENTED

MemoryInfo GetMemoryInfo(const std::string& device) override {
Expand All @@ -97,6 +106,7 @@ class PjRtComputationClient : public ComputationClient {

private:
std::shared_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
// global_ordinals_ tracks a map from PjRtDeviceId to the device's
// dense global ordinal.
std::unordered_map<int, int> global_ordinals_;
Expand Down
Loading

0 comments on commit 244efa7

Please sign in to comment.