Skip to content

Commit

Permalink
Ensure ComputationClient is shut down if initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 28, 2023
1 parent 1502aa8 commit 9aafdf5
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 2 deletions.
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ PjRtComputationClient::PjRtComputationClient() {
device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());
}

// Destructor: drops the PjRt client and the coordinator by resetting both
// members to nullptr.
// NOTE(review): presumably client_ and coordinator_ are owning smart pointers,
// so each assignment destroys the held object — confirm against the header.
PjRtComputationClient::~PjRtComputationClient() {
// The client is reset first, the coordinator afterwards.
// NOTE(review): this order looks deliberate (client torn down before the
// coordinator); keep it unless the dependency is verified otherwise.
client_ = nullptr;
coordinator_ = nullptr;
}

bool PjRtComputationClient::CoordinatorInitialized() const {
return coordinator_ != nullptr;
}
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ namespace runtime {
class PjRtComputationClient : public ComputationClient {
public:
PjRtComputationClient();
~PjRtComputationClient();

DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override;

Expand Down
8 changes: 6 additions & 2 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <torch/csrc/lazy/backend/backend_device.h>

#include <cstdlib>

#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/env_vars.h"
Expand Down Expand Up @@ -34,8 +36,10 @@ ComputationClient* CreateClient() {
} // namespace

// Returns the process-wide ComputationClient, creating it on first use.
//
// Initialization runs exactly once, guarded by `g_computation_client_once`.
// The same one-time step registers an `std::atexit` handler that deletes the
// client, so the client is only torn down at exit if it was ever created.
//
// The client must be created and the exit handler registered inside a single
// `std::call_once`: a second `call_once` on the same flag would be a no-op and
// the cleanup handler would never be installed.
ComputationClient* GetComputationClient() {
  std::call_once(g_computation_client_once, []() {
    // CreateClient() returns a raw pointer; wrapping it in std::move adds
    // nothing, so it is stored directly.
    g_computation_client = CreateClient();
    // Destroy the client at normal process exit.
    std::atexit([]() { delete g_computation_client.load(); });
  });
  return g_computation_client.load();
}

Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/xla_coordinator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size,
}

XlaCoordinator::~XlaCoordinator() {
preemption_sync_manager_ = nullptr;
if (dist_runtime_client_ != nullptr) {
XLA_CHECK(dist_runtime_client_->Shutdown().ok())
<< "Failed to shut down the distributed runtime client.";
Expand Down

0 comments on commit 9aafdf5

Please sign in to comment.