From 9aafdf5836658da7fc25109d3ef5c32fe0835491 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Fri, 27 Oct 2023 22:18:17 +0000 Subject: [PATCH] Ensure ComputationClient is shut down if initialized --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 5 +++++ torch_xla/csrc/runtime/pjrt_computation_client.h | 1 + torch_xla/csrc/runtime/runtime.cc | 8 ++++++-- torch_xla/csrc/runtime/xla_coordinator.cc | 1 + 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 411bf8e27cac..ae1305695452 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -188,6 +188,11 @@ PjRtComputationClient::PjRtComputationClient() { device_locks_.emplace(spmd_device_str, std::make_unique()); } +PjRtComputationClient::~PjRtComputationClient() { + client_ = nullptr; + coordinator_ = nullptr; +} + bool PjRtComputationClient::CoordinatorInitialized() const { return coordinator_ != nullptr; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 4ea3d96acefd..6704491c41b4 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -23,6 +23,7 @@ namespace runtime { class PjRtComputationClient : public ComputationClient { public: PjRtComputationClient(); + ~PjRtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 8cfd06951842..ecb90f391670 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -1,5 +1,7 @@ #include +#include + #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" @@ -34,8 +36,10 @@ ComputationClient* CreateClient() { } // namespace ComputationClient* GetComputationClient() { - std::call_once(g_computation_client_once, - [&]() { g_computation_client = std::move(CreateClient()); }); + std::call_once(g_computation_client_once, [&]() { + g_computation_client = std::move(CreateClient()); + std::atexit([]() { delete g_computation_client.load(); }); + }); return g_computation_client.load(); } diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc index 606fe5cb470a..72855d8681ea 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -30,6 +30,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size, } XlaCoordinator::~XlaCoordinator() { + preemption_sync_manager_ = nullptr; if (dist_runtime_client_ != nullptr) { XLA_CHECK(dist_runtime_client_->Shutdown().ok()) << "Failed to shut down the distributed runtime client.";