Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up and add doc for dynamo_mark_sharding #7281

Merged
merged 4 commits into from
Jun 14, 2024
Merged

Conversation

wonjoolee95
Copy link
Collaborator

  1. Clean up unused code for dynamo_mark_sharding (remove flag use_dynamo_custom_op since it's not used anymore)
  2. Update spmd docs with an example.

@wonjoolee95 wonjoolee95 requested a review from JackCaoG June 14, 2024 20:11
docs/spmd.md Outdated
import torch_xla.experimental.dynamo_mark_sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = 'None' # string version of axis_names
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use axis_names instead of numbers for partition_spec ? I want to update this doc to encourage user to give each mesh axis a name. I think this is more clear.

Copy link
Collaborator Author

@wonjoolee95 wonjoolee95 Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! Just updated.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in my new pr #7285 I refactored the spmd.md. I can let you merge first and I will move your chagne to spmd_advanced.md

docs/spmd.md Outdated
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = '(data, model)' # string version of axis_names
partition_spec = '(model, data)' # string version of partition spec
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this right? I believe you need "('data', 'model')"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right. Updated.

docs/spmd.md Outdated
import torch_xla.experimental.dynamo_mark_sharding
device_ids = [i for i in range(self.num_devices)] # List[int]
mesh_shape = [self.num_devices//2, 1, 2] # List[int]
axis_names = '(data, model)' # string version of axis_names
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, I think you need "('data', 'model')"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@wonjoolee95
Copy link
Collaborator Author

Skipping the CI as it is just a docs change.

@wonjoolee95 wonjoolee95 merged commit 61e0389 into master Jun 14, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants