Support PreemptionSyncManager in XlaCoordinator #5733
Conversation
LGTM!
@@ -50,5 +47,22 @@ std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
  return dist_runtime_client_;
}

void DistributedRuntime::ActivatePreemptionSyncManager() {
  if (preemption_sync_manager_ == nullptr) {
Is there any harm in initializing the PreemptionSyncManager when you initialize the xla::DistributedRuntimeService and Client? In general, I try to avoid cases where you "partially" construct an object and leave the door open for bugs later (like calling ReachedSyncPoint before ActivatePreemptionSyncManager).
I was hesitant to do that, since it will register a SIGTERM handler, which will cause any intentional SIGTERMs to be ignored. Open to revisiting, let me know which approach you think makes more sense!
#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {

// DistributedRuntime serves as the point of entry for all operations which
// require the XLA distributed runtime, such as preemption coordination.
class DistributedRuntime {
I really dislike the naming choice for the upstream xla::DistributedRuntime, since it's not actually a distributed runtime. Since this class is becoming more than just a wrapper around xla::DistributedRuntimeService and xla::DistributedRuntimeClient, what do you think of changing the name to something more intuitive? e.g. XlaCoordinator
Totally agree, XlaCoordinator it is! I'll update the pybinds as well.
I left as-is for now to keep this change minimal, we can revisit in the upcoming refactor.
// The PreemptionSyncManager must be activated within the DistributedRuntime.
// Returns true when the input step has been identified as a sync point, and
// false otherwise.
bool ReachedSyncPoint(int step);
Do you think it makes more sense to expose the tsl::PreemptionSyncManager directly as we do with the xla::DistributedRuntimeClient? Or do we want to restrict access to the underlying object?
I considered that, but if the PreemptionSyncManager outlives the DistributedRuntimeClient, the program will segfault... 😢 I figured it's better to keep it hidden to avoid that edge case.
Thanks for the review @vanbasten23 and @will-cromar! I'll update to address the feedback.
Oh one more thing, could you help check if we have a test verifying that the distributed runtime service is always torn down? I'd imagine if we comment out the line
@vanbasten23 @will-cromar I've updated to have the ComputationClient own the XlaCoordinator. Please take a second look when you get a chance!
Overall LGTM.
FYI, our style guide cautions against forward declarations of entities in other projects, even if it saves compile time: https://google.github.io/styleguide/cppguide.html#Forward_Declarations
If you forward-declared the DistributedRuntime classes to unravel e.g. a macro conflict or circular dependency, please leave a comment explaining why.
@@ -350,6 +353,17 @@ class ComputationClient {
  // the local devices will be waited for.
  virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;

  // Check whether the XlaCoordinator has been initialized.
  virtual bool CoordinatorInitialized() const = 0;
Do these need to be virtual? It looks like the implementations below don't depend on the underlying runtime client.
The XlaCoordinator depends on PJRT, so I kept it separate. Though I guess that's not a strong justification...
@@ -30,6 +30,12 @@
namespace torch_xla {
namespace runtime {

// Forward declare XlaCoordinator to avoid logging macro redefinition from the
// transitively included PJRT header.
// TODO(jonbolin): We need a way to ensure the right macros are included
logging macros are cursed
I guess I had a hard time imagining how this is going to be incorporated into the ckpt mgr. @jonb377 Can you point me to some examples?
@alanwaketan CheckpointManager will initialize the PreemptionSyncManager on construction and call into
@alanwaketan See 5fdce13 for the intended usage.
I see, that makes sense.
LGTM.
* Support PreemptionSyncManager in DistributedRuntime
* Refactor to be owned by ComputationClient
* Clean up logging macro issue handling
To support autocheckpointing upon preemption, we need access to a PreemptionSyncManager to identify sync points when a preemption has occurred.
This change additionally refactors the DistributedRuntime to be owned by the ComputationClient, since in the GPU case the ComputationClient has a direct dependency on the DistributedRuntimeClient.
This change adds the PreemptionSyncManager to the new XlaCoordinator class. The PreemptionSyncManager has the side effect of registering a SIGTERM handler, so it is not enabled by default.