
Enable eager spmd #7341

Merged: 4 commits merged into master from JackCaoG/eager_spmd on Jun 26, 2024
Conversation

@JackCaoG (Collaborator)
No description provided.

@JackCaoG force-pushed the JackCaoG/eager_spmd branch from 3ec9f92 to 4244dcb on June 25, 2024 at 22:55
@JackCaoG marked this pull request as ready for review on June 26, 2024 at 18:30
@alanwaketan (Collaborator)
How could SPMD possibly work for eager mode?

@JackCaoG (Collaborator, Author)
> How could SPMD possibly work for eager mode?

Consider eager mode as calling mark_step after every PyTorch op.
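A minimal sketch of what that means in user code, assuming the torch_xla.experimental.eager_mode entry point from the broader eager-mode work (it is not part of this PR's diff):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Assumed eager toggle from the eager-mode work; not defined in this PR.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
t1 = torch.randn(4, 4, device=device)

# Conceptually, each op below behaves as if mark_step() ran right after it:
# a tiny graph is traced, compiled (or fetched from the cache), and executed
# immediately instead of being accumulated into one large lazy graph.
t2 = t1.cos()
t3 = t2 + t1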

@alanwaketan (Collaborator)

> > How could SPMD possibly work for eager mode?
>
> Consider eager mode as calling mark_step after every PyTorch op.

Then how do sharding propagation and auto partitioning work? I assume they don't carry state over from the last graph?

@JackCaoG (Collaborator, Author)

> > > How could SPMD possibly work for eager mode?
> >
> > Consider eager mode as calling mark_step after every PyTorch op.
>
> Then how do sharding propagation and auto partitioning work? I assume they don't carry state over from the last graph?

The sharding propagation and auto partitioning still happen within each subgraph we compile. For example:

t3 = t1.cos()
t3 += t2

We will compile a graph for cos, which will calculate the output sharding for t3 and then assign a sharded PJRT buffer (not yet ready) to t3. We will then proceed with another graph for the add; at that point the input sharding for t3 is known, so we just propagate it to the output.
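A rough sketch of that flow, using the standard torch_xla.distributed.spmd mesh/sharding helpers plus the assumed eager toggle mentioned above (this is an illustration, not code from this PR):

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
torch_xla.experimental.eager_mode(True)  # assumed eager toggle

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

device = xm.xla_device()
t1 = torch.randn(8, 8, device=device)
t2 = torch.randn(8, 8, device=device)
xs.mark_sharding(t1, mesh, ('data', 'model'))

# Graph 1: just `cos`. Sharding propagation runs inside this one-op graph,
# computes the output sharding for t3, and assigns it a sharded PJRT buffer
# that is not yet ready.
t3 = t1.cos()

# Graph 2: just the add. t3's input sharding is now known, so it simply
# propagates to this graph's output.
t3 += t2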

@alanwaketan (Collaborator)

Okay, that's fair.

@@ -2742,7 +2742,9 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,

   // 2) Aid SPMD.
   XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
-  if (sharding && sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
+  // don't propagate sharding in eager mode.
+  if (!XLAGraphExecutor::Get()->UseEagerMode() && sharding &&
@alanwaketan (Collaborator) commented on this change:
May I ask why?

@JackCaoG (Collaborator, Author)

It complained that the output tensor already has a sharding, so we can't propagate to it. This happens in the backward pass. I didn't spend enough time to debug it, but I don't expect users to actually run the step fn (forward and backward) in eager mode; I only expect them to use it for some data preprocessing on device, so I just quickly unblocked myself.
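For context, the usage pattern described here (eager mode only for on-device data preprocessing, with the step fn still compiled) might look roughly like the sketch below; torch_xla.experimental.compile and the eager toggle are assumed from the wider eager-mode work, not taken from this PR:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)  # assumed toggle: ops run eagerly by default

device = xm.xla_device()
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def preprocess(batch):
  # Light on-device preprocessing runs eagerly, one small graph per op.
  return (batch - batch.mean()) / (batch.std() + 1e-6)

def train_step(batch, labels):
  optimizer.zero_grad()
  loss = torch.nn.functional.mse_loss(model(batch), labels)
  loss.backward()
  optimizer.step()
  return loss

# Assumed wrapper: the step fn (forward + backward) is still traced and
# compiled as one graph rather than executed op by op.
train_step = torch_xla.experimental.compile(train_step)

batch = preprocess(torch.randn(8, 16, device=device))
labels = torch.randn(8, 1, device=device)
loss = train_step(batch, labels)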

@alanwaketan (Collaborator) left a review:

LGTM. Thanks, Jack!

@JackCaoG merged commit d5e5713 into master on Jun 26, 2024
23 checks passed
JackCaoG added a commit that referenced this pull request Jul 12, 2024
bhavya01 pushed a commit that referenced this pull request Jul 15, 2024