Skip to content

Commit

Permalink
Delete ComputationClient in PrepareToExit
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 30, 2023
1 parent e1c4244 commit d2a36a6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void PrepareToExit() {
if (client != nullptr) {
XLAGraphExecutor::Get()->WaitDeviceOps({});
client->PrepareToExit();
runtime::DeleteComputationClient();
}
}

Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ PjRtComputationClient::PjRtComputationClient() {
}

PjRtComputationClient::~PjRtComputationClient() {
// In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
// tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
client_ = nullptr;
coordinator_ = nullptr;
}
Expand Down
16 changes: 10 additions & 6 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include <torch/csrc/lazy/backend/backend_device.h>

#include <cstdlib>

#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/env_vars.h"
Expand Down Expand Up @@ -35,11 +33,17 @@ ComputationClient* CreateClient() {

} // namespace

void DeleteComputationClient() {
ComputationClient *client = g_computation_client.load();
if (client != nullptr) {
g_computation_client = nullptr;
delete client;
}
}

ComputationClient* GetComputationClient() {
std::call_once(g_computation_client_once, [&]() {
g_computation_client = std::move(CreateClient());
std::atexit([]() { delete g_computation_client.load(); });
});
std::call_once(g_computation_client_once,
[&]() { g_computation_client = std::move(CreateClient()); });
return g_computation_client.load();
}

Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ ComputationClient* GetComputationClient();

ComputationClient* GetComputationClientIfInitialized();

// Delete the ComputationClient. This should be called when the program is about
// to exit.
void DeleteComputationClient();

// Run the XRT local service, this will block the caller unitl the server
// being stopped.
void RunLocalService(uint64_t service_port);
Expand Down

0 comments on commit d2a36a6

Please sign in to comment.