diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 411bf8e27ca..ae130569545 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -188,6 +188,11 @@ PjRtComputationClient::PjRtComputationClient() { device_locks_.emplace(spmd_device_str, std::make_unique()); } +PjRtComputationClient::~PjRtComputationClient() { + client_ = nullptr; + coordinator_ = nullptr; +} + bool PjRtComputationClient::CoordinatorInitialized() const { return coordinator_ != nullptr; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 4ea3d96acef..6704491c41b 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -23,6 +23,7 @@ namespace runtime { class PjRtComputationClient : public ComputationClient { public: PjRtComputationClient(); + ~PjRtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 8cfd0695184..ecb90f39167 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -1,5 +1,7 @@ #include +#include + #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" @@ -34,8 +36,10 @@ ComputationClient* CreateClient() { } // namespace ComputationClient* GetComputationClient() { - std::call_once(g_computation_client_once, - [&]() { g_computation_client = std::move(CreateClient()); }); + std::call_once(g_computation_client_once, [&]() { + g_computation_client = std::move(CreateClient()); + std::atexit([]() { delete g_computation_client.load(); }); + }); return g_computation_client.load(); } diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc index 606fe5cb470..72855d8681e 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -30,6 +30,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size, } XlaCoordinator::~XlaCoordinator() { + preemption_sync_manager_ = nullptr; if (dist_runtime_client_ != nullptr) { XLA_CHECK(dist_runtime_client_->Shutdown().ok()) << "Failed to shut down the distributed runtime client.";