
Support PreemptionSyncManager in XlaCoordinator #5733

Merged: jonb377 merged 3 commits into master from jonbolin/preemption on Nov 1, 2023
Conversation

@jonb377 (Collaborator) commented Oct 25, 2023:

To support autocheckpointing upon preemption, we need access to a PreemptionSyncManager to identify sync points when a preemption has occurred.

This change additionally refactors the DistributedRuntime to be owned by the ComputationClient, since in the GPU case the ComputationClient has a direct dependency on the DistributedRuntimeClient.

This change adds the PreemptionSyncManager to the new XlaCoordinator class. Because the PreemptionSyncManager registers a SIGTERM handler as a side effect, it is not enabled by default.
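The opt-in design described above (the sync manager is only created on explicit activation, so constructing the coordinator has no signal-handler side effects) can be sketched as follows. This is a toy Python illustration of the pattern, not the actual C++ XlaCoordinator API; all names are stand-ins.

```python
class CoordinatorSketch:
    """Toy stand-in for the XlaCoordinator design (names are illustrative)."""

    def __init__(self):
        # No sync manager yet: constructing the coordinator must not
        # register any signal handlers.
        self._sync_manager = None

    def activate_preemption_sync_manager(self):
        # Idempotent opt-in. Only here would the real implementation
        # create the tsl::PreemptionSyncManager and hook SIGTERM.
        if self._sync_manager is None:
            self._sync_manager = {"sync_step": None}

    def reached_sync_point(self, step):
        # Guard: without activation there is never a sync point.
        if self._sync_manager is None:
            return False
        return self._sync_manager["sync_step"] == step


coord = CoordinatorSketch()
assert not coord.reached_sync_point(0)  # inactive: always False
coord.activate_preemption_sync_manager()
coord._sync_manager["sync_step"] = 3    # simulate a preemption notice
assert coord.reached_sync_point(3)
```

The guard in `reached_sync_point` is what keeps "partial construction" safe to observe, which is the subject of the review thread below.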

@vanbasten23 (Collaborator) left a comment:

LGTM!

@@ -50,5 +47,22 @@ std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
return dist_runtime_client_;
}

void DistributedRuntime::ActivatePreemptionSyncManager() {
if (preemption_sync_manager_ == nullptr) {
A reviewer (Collaborator) commented:

Is there any harm in initializing the PreemptionSyncManager when you initialize the xla::DistributedRuntimeService and Client? In general, I try to avoid cases where you "partially" construct an object and leave potential bugs to happen later (like calling ReachedSyncPoint before ActivatePreemptionSyncManager)

@jonb377 (author) replied:

I was hesitant to do that, since it will register a SIGTERM handler, which will cause any intentional SIGTERMs to be ignored. Open to revisiting; let me know which approach you think makes more sense!
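The side effect at issue here can be demonstrated in isolation: once a process installs its own SIGTERM handler, the default terminate-on-SIGTERM behavior is replaced, so an "intentional" SIGTERM no longer kills the process. A minimal, POSIX-only Python sketch (unrelated to the actual tsl implementation):

```python
import os
import signal
import time

preempted = {"flag": False}

def on_sigterm(signum, frame):
    # Instead of terminating, the process now just records the signal.
    # This is exactly the side effect discussed above.
    preempted["flag"] = True

signal.signal(signal.SIGTERM, on_sigterm)
os.kill(os.getpid(), signal.SIGTERM)  # an "intentional" SIGTERM

# Give the interpreter a moment to run the Python-level handler.
deadline = time.time() + 5
while not preempted["flag"] and time.time() < deadline:
    time.sleep(0.01)

assert preempted["flag"]  # the process survived the SIGTERM
```

This is why activation is kept explicit: a job that relies on SIGTERM to stop workers would otherwise silently change behavior just by constructing the coordinator.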

torch_xla/csrc/runtime/distributed_runtime.h (outdated review threads, resolved)
#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {

// DistributedRuntime serves as the point of entry for all operations which
// require the XLA distributed runtime, such as preemption coordination.
class DistributedRuntime {
A reviewer (Collaborator) commented:

I really dislike the naming choice for the upstream xla::DistributedRuntime, since it's not actually a distributed runtime. Since this class is becoming more than just a wrapper around xla::DistributedRuntimeService and xla::DistributedRuntimeClient, what do you think of changing the name to something more intuitive? e.g. XlaCoordinator

@jonb377 (author) replied:

Totally agree, XlaCoordinator it is! I'll update the pybinds as well.

@jonb377 (author) added:

I left as-is for now to keep this change minimal, we can revisit in the upcoming refactor.

// The PreemptionSyncManager must be activated within the DistributedRuntime.
// Returns true when the input step has been identified as a sync point, and
// false otherwise.
bool ReachedSyncPoint(int step);
A reviewer (Collaborator) commented:

Do you think it makes more sense to expose the tsl::PreemptionSyncManager directly as we do with the xla::DistributedRuntimeClient? Or do we want to restrict access to the underlying object?

@jonb377 (author) replied:

I considered that, but if the PreemptionSyncManager outlives the DistributedRuntimeClient, the program will segfault... 😢 I figured it's better to keep it hidden to avoid that edge case.
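The encapsulation choice above (expose only a bool-returning method, never the underlying manager) can be sketched like this. The teardown-order comment is the point; the class and field names are made up for illustration:

```python
class ShutdownSafeCoordinator:
    """Owns both objects so callers can never hold a dangling manager."""

    def __init__(self):
        self._client = {"alive": True}        # stub DistributedRuntimeClient
        self._sync_manager = {"alive": True}  # stub PreemptionSyncManager

    def reached_sync_point(self, step):
        # The only public surface is a bool, not the manager itself, so
        # no caller can keep a reference past the coordinator's lifetime.
        if self._sync_manager is None:
            raise RuntimeError("coordinator has been shut down")
        return False

    def shutdown(self):
        # Tear down in dependency order: the sync manager must not
        # outlive the client it talks to.
        self._sync_manager = None
        self._client = None
```

In Python a stale call fails loudly with an exception; the C++ concern is worse, since a caller-held raw pointer would dereference freed memory instead.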

@jonb377 (author) left a comment:

Thanks for the review @vanbasten23 and @will-cromar! I'll update to address the feedback.

@vanbasten23 (Collaborator) commented:
Oh, one more thing: could you help check whether we have a test verifying that the distributed runtime service is always torn down? I'd imagine that if we comment out the line dist_runtime_service_->Shutdown();, the test PJRT_DEVICE=GPU torchrun --nnodes 1 --nproc-per-node 2 pytorch/xla/test/pjrt/test_torchrun.py should time out.
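The suggested regression test boils down to "the command must finish before a deadline". A generic helper of that shape (a sketch, not the actual test harness, and the commands below are placeholders rather than the real torchrun invocation) could look like:

```python
import subprocess
import sys

def finishes_within(cmd, timeout_s):
    """Return True if the command exits before the timeout, else False."""
    try:
        subprocess.run(cmd, timeout=timeout_s, check=False)
        return True
    except subprocess.TimeoutExpired:
        return False

# A missing service shutdown would look like the second case:
assert finishes_within([sys.executable, "-c", "pass"], 30)
assert not finishes_within(
    [sys.executable, "-c", "import time; time.sleep(10)"], 1)
```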

@jonb377 force-pushed the jonbolin/preemption branch 2 times, most recently from a5c5e74 to 10dc8db (October 27, 2023)
@jonb377 (author) commented Oct 27, 2023:

@vanbasten23 @will-cromar I've updated to have the ComputationClient own the XlaCoordinator. Please take a second look when you get a chance!

@will-cromar (Collaborator) left a comment:

Overall LGTM.

FYI, our style guide cautions against forward declarations of entities in other projects, even if it saves compile time: https://google.github.io/styleguide/cppguide.html#Forward_Declarations

If you forward-declared the DistributedRuntime classes to unravel e.g. a macro conflict or circular dependency, please leave a comment explaining why.

torch_xla/csrc/runtime/computation_client.h (resolved)
torch_xla/csrc/runtime/pjrt_computation_client.cc (outdated review threads, resolved)
torch_xla/csrc/runtime/xla_coordinator.h (outdated, resolved)
@@ -350,6 +353,17 @@ class ComputationClient {
// the local devices will be waited for.
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;

// Check whether the XlaCoordinator has been initialized.
virtual bool CoordinatorInitialized() const = 0;
A reviewer (Collaborator) commented:

Do these need to be virtual? It looks like the implementations below don't depend on the underlying runtime client.

@jonb377 (author) replied Oct 27, 2023:

The XlaCoordinator depends on PJRT, so I kept it separate. Though I guess that's not a strong justification...

@jonb377 jonb377 changed the title Support PreemptionSyncManager in DistributedRuntime Support PreemptionSyncManager in XlaCoordinator Oct 28, 2023
@jonb377 force-pushed the jonbolin/preemption branch 4 times, most recently from b954fed to 9aafdf5 (October 28, 2023)
@@ -30,6 +30,12 @@
namespace torch_xla {
namespace runtime {

// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
A reviewer (Collaborator) commented:

logging macros are cursed ☹️

@jonb377 force-pushed the jonbolin/preemption branch 2 times, most recently from d2a36a6 to 673cad9 (October 30, 2023)
@alanwaketan (Collaborator) left a comment:

I'm having a hard time imagining how this is going to be incorporated into the checkpoint manager. @jonb377, can you point me to some examples?

@jonb377 (author) commented Oct 30, 2023:

> I'm having a hard time imagining how this is going to be incorporated into the checkpoint manager. @jonb377, can you point me to some examples?

@alanwaketan CheckpointManager will initialize the PreemptionSyncManager on construction and call into _sync_point_reached each step to check for preemption in should_save. I'll open a draft PR to illustrate.
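That flow, periodic saves plus a forced save when a sync point is hit, can be sketched as follows. This is a toy illustration, not the real CheckpointManager; `sync_point_reached` stands in for the `_sync_point_reached` binding mentioned above and is stubbed here:

```python
class CheckpointManagerSketch:
    """Toy version of the should_save logic described above."""

    def __init__(self, save_interval, sync_point_reached):
        self._save_interval = save_interval
        # Injected callable so the sketch needs no runtime; the real
        # manager would call into the preemption sync manager instead.
        self._sync_point_reached = sync_point_reached

    def should_save(self, step):
        on_schedule = step % self._save_interval == 0
        preempting = self._sync_point_reached(step)
        return on_schedule or preempting


mgr = CheckpointManagerSketch(save_interval=10,
                              sync_point_reached=lambda step: step == 7)
assert mgr.should_save(0)      # on schedule
assert not mgr.should_save(3)  # neither condition holds
assert mgr.should_save(7)      # preemption sync point forces a save
```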

@jonb377 (author) commented Oct 30, 2023:

@alanwaketan See 5fdce13 for the intended usage.

@alanwaketan (Collaborator) replied:

> @alanwaketan CheckpointManager will initialize the PreemptionSyncManager on construction and call into _sync_point_reached each step to check for preemption in should_save. I'll open a draft PR to illustrate.

I see, that makes sense.

@alanwaketan (Collaborator) left a comment:

LGTM.

@jonb377 force-pushed the jonbolin/preemption branch 3 times, most recently from 0d35028 to cfbb93a (October 30, 2023)
@jonb377 jonb377 changed the title Support PreemptionSyncManager in XlaCoordinator Support PreemptionSyncManager in DistributedRuntime Oct 30, 2023
@jonb377 jonb377 changed the title Support PreemptionSyncManager in DistributedRuntime Support PreemptionSyncManager in XlaCoordinator Oct 31, 2023
@jonb377 jonb377 merged commit b20a082 into master Nov 1, 2023
19 checks passed
@jonb377 jonb377 deleted the jonbolin/preemption branch November 1, 2023 16:50
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* Support PreemptionSyncManager in DistributedRuntime

* Refactor to be owned by ComputationClient

* Clean up logging macro issue handling
Commits listing the same three changes and referencing this pull request were also pushed by ManfeiBai (Nov 29, 2023, twice), chunnienc to chunnienc/xla (Dec 14, 2023), golechwierowicz (Jan 12, 2024), and bhavya01 (Apr 22, 2024).
5 participants