Add model parallel distribution. #797
Conversation
Thanks for the PR!
distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
with distribution.scope():
If the primary usage is via the global setter, we should show that in the code example
Done.
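For reference, a rough sketch of what usage through the global setter could look like, assuming the constructor signature shown in the diff and the `keras.distribution.set_distribution()` / `list_devices()` helpers; the mesh shape, layer names, and layout regexes below are illustrative only:

```python
import keras
from keras import distribution

# Illustrative 8-device setup (assumes 8 accelerators are available):
# 2-way data parallel x 4-way model parallel.
devices = distribution.list_devices()
device_mesh = distribution.DeviceMesh(
    shape=(2, 4), axis_names=["batch", "model"], devices=devices
)

# Map variable paths (by regex) to layouts; anything unmatched stays replicated.
layout_map = distribution.LayoutMap(device_mesh)
layout_map["dense.*kernel"] = (None, "model")
layout_map["dense.*bias"] = ("model",)

model_parallel = distribution.ModelParallel(
    device_mesh=device_mesh,
    layout_map=layout_map,
    batch_dim_name="batch",
)

# Global setter: models built afterwards pick up the sharded layouts.
distribution.set_distribution(model_parallel)

inputs = keras.Input(shape=(128,))
outputs = keras.layers.Dense(64, name="dense_1")(inputs)
model = keras.Model(inputs, outputs)
```

The idea being that once the distribution is set globally, any model built afterwards picks up the sharded layouts without an explicit `scope()` block.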
corresponding `TensorLayout`.

Example:
```
Add `python` after the opening code fence.
Done.
devices=devices)
```
To figure out a proper layout mapping rule for all the model weights, you |
Do variables other than weights ever need to be sharded? e.g. optimizer variables, metrics. I assume optimizer variables will need to be sharded.
In the DTensor implementation, all the weights need to be either replicated or sharded; a plain tf.Variable doesn't work alongside DTensor variables within one tf.function.
For JAX, it might not be explicitly required, but it is probably done (replicated by default) under the hood.
For optimizer variables, since they have a similar name/path to the weights, they will probably get the same layout as the corresponding weight. Metric variables are replicated by default.
> In the DTensor implementation, all the weights need to be either replicated/sharded

In this case, we should show the code example with `model.variables`, which is more complete. Or we could move `get_variable_map()` to the public API and recommend that. That's probably the best move, since it includes the optimizer variables as well.

> For optimizer variables, since it has similar name/path as the weights name

This isn't the case today -- try printing some optimizer variables for a given model for some examples. We should figure out the recommended best practices for specifying layouts for optimizer variables.
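To make that concrete, a rough sketch of the kind of inspection being suggested, assuming Keras variables expose a `path` attribute and the optimizer exposes its state via `optimizer.variables`; the toy model and printed paths are illustrative, and the exact optimizer-variable naming depends on the optimizer implementation:

```python
import keras

# Build a toy model and an optimizer so all variables exist.
model = keras.Sequential(
    [keras.Input(shape=(128,)), keras.layers.Dense(64, name="dense_1")]
)
optimizer = keras.optimizers.Adam()
optimizer.build(model.trainable_variables)

# Print every variable path so layout rules can be written against them.
for variable in model.variables + optimizer.variables:
    print(variable.path, tuple(variable.shape))
# Weight paths look something like "dense_1/kernel"; optimizer slot
# variables follow a different naming scheme, which is why a weight-name
# regex may not match them directly.
```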
corresponding `TensorLayout`. The axis names of the
`TensorLayout`s should match the axis names in the
device_mesh, or an exception will be raised.
batch_dim_name: optional string, the axis name in the device_mesh
Is this argument necessary? Can it not be inferred every time?
I'd prefer not to infer it, which could lead to some bizarre behavior; on the JAX backend, it might silently run without raising an error.
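For illustration, a small sketch (same assumed constructor signature as in the diff) of a mesh whose axis names don't make the batch dimension obvious; the axis names and mesh shape are made up, and 8 devices are assumed:

```python
from keras import distribution

# A 4x2 mesh where neither axis is named "batch", so inferring which one
# carries the batch dimension would be a guess.
device_mesh = distribution.DeviceMesh(
    shape=(4, 2),
    axis_names=["replica", "model"],
    devices=distribution.list_devices(),
)
layout_map = distribution.LayoutMap(device_mesh)

# Being explicit removes the ambiguity: input batches are split on "replica".
model_parallel = distribution.ModelParallel(
    device_mesh=device_mesh,
    layout_map=layout_map,
    batch_dim_name="replica",
)
```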
LGTM, thanks!