-
Notifications
You must be signed in to change notification settings - Fork 486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up and add doc for dynamo_mark_sharding #7281
Conversation
docs/spmd.md
Outdated
import torch_xla.experimental.dynamo_mark_sharding | ||
device_ids = [i for i in range(self.num_devices)] # List[int] | ||
mesh_shape = [self.num_devices//2, 1, 2] # List[int] | ||
axis_names = 'None' # string version of axis_names |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use axis_names
instead of numbers for partition_spec
? I want to update this doc to encourage user to give each mesh axis a name. I think this is more clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense! Just updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in my new pr #7285 I refactored the spmd.md
. I can let you merge first and I will move your chagne to spmd_advanced.md
docs/spmd.md
Outdated
device_ids = [i for i in range(self.num_devices)] # List[int] | ||
mesh_shape = [self.num_devices//2, 1, 2] # List[int] | ||
axis_names = '(data, model)' # string version of axis_names | ||
partition_spec = '(model, data)' # string version of partition spec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this right? I believe you need "('data', 'model')"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're absolutely right. Updated.
docs/spmd.md
Outdated
import torch_xla.experimental.dynamo_mark_sharding | ||
device_ids = [i for i in range(self.num_devices)] # List[int] | ||
mesh_shape = [self.num_devices//2, 1, 2] # List[int] | ||
axis_names = '(data, model)' # string version of axis_names |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, I think you need "('data', 'model')"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
Skipping the CI as it is just a docs change. |
use_dynamo_custom_op
since it's not used anymore)