
Add model parallel distribution. #797

Merged
merged 2 commits into keras-team:main from the true_model_parallel branch on Aug 27, 2023

Conversation

qlzh727
Member

@qlzh727 qlzh727 commented Aug 26, 2023

No description provided.

@qlzh727 qlzh727 requested a review from fchollet August 26, 2023 20:05
Contributor

@fchollet fchollet left a comment


Thanks for the PR!

distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
with distribution.scope():
Contributor


If the primary usage is via the global setter, we should show that in the code example

Member Author


Done.
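
For readers following along, here is a minimal sketch of the global-setter usage the review asks for. The module path `keras.distribution` and the names `set_distribution`, `DeviceMesh`, `LayoutMap`, and `list_devices` are assumptions taken from the current Keras distribution API, not quoted from this PR's diff.

```python
# Sketch only: assumes the keras.distribution module and the names below exist
# as in the current Keras 3 API; the docstring in this PR is authoritative.
from keras.distribution import (
    DeviceMesh,
    LayoutMap,
    ModelParallel,
    list_devices,
    set_distribution,
)

devices = list_devices()  # e.g. 8 accelerators

# 2D mesh: a data-parallel "batch" axis and a model-parallel "model" axis.
device_mesh = DeviceMesh(
    shape=(2, 4), axis_names=("batch", "model"), devices=devices
)

# Regex keys on variable paths -> sharding spec over the mesh axes.
layout_map = LayoutMap(device_mesh)
layout_map["dense.*kernel"] = (None, "model")  # shard last dim over "model"
layout_map["dense.*bias"] = ("model",)

distribution = ModelParallel(
    device_mesh=device_mesh, layout_map=layout_map, batch_dim_name="batch"
)

# Global setter: models and variables created afterwards pick this up.
set_distribution(distribution)
```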

corresponding `TensorLayout`.

Example:
```
Contributor


Add: python

Member Author


Done.

devices=devices)
```

To figure out a proper layout mapping rule for all the model weights, you
Contributor


Do variables other than weights ever need to be sharded? e.g. optimizer variables, metrics. I assume optimizer variables will need to be sharded.

Member Author


In the DTensor implementation, all the weights need to be either replicated or sharded; a plain tf.Variable doesn't work alongside DTensor variables within one tf.function.

For JAX, it might not be explicitly required, but it is likely handled (replicated by default) under the hood.

For optimizer variables, since they have a similar name/path to the weight names, they will probably get the same layout as the corresponding weight. Metric variables are replicated by default.

Contributor


> In the DTensor implementation, all the weights need to be either replicated or sharded

In this case, we should show the code example with `model.variables`, which is more complete. Or we could move `get_variable_map()` to the public API and recommend that. It's probably the best move since it includes optimizer variables as well.

> For optimizer variables, since they have a similar name/path to the weight names

This isn't the case today -- try printing some optimizer variables for a given model to see examples. We should figure out the recommended best practices for specifying layouts for optimizer variables.
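
As a concrete way to follow the suggestion above, a hedged sketch of printing the variable paths for both the model and its optimizer; the `.path` attribute on variables and `optimizer.build()` are assumptions based on the Keras 3 API and may differ from this PR's code.

```python
# Hypothetical inspection snippet (assumes Keras 3-style `Variable.path` and
# `optimizer.build()`): print paths to see how model weights and optimizer
# slot variables are named before writing LayoutMap regex rules.
import keras

model = keras.Sequential(
    [
        keras.layers.Dense(64, name="dense_1"),
        keras.layers.Dense(8, name="dense_2"),
    ]
)
model.build(input_shape=(None, 16))

optimizer = keras.optimizers.Adam()
optimizer.build(model.trainable_variables)  # creates momentum/velocity slots

for v in model.variables:
    print("model var:     ", v.path, tuple(v.shape))
for v in optimizer.variables:
    print("optimizer var: ", v.path, tuple(v.shape))
```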

corresponding `TensorLayout`. The axis names of the
`TensorLayout`s should match the axis names in the
device_mesh, or an exception will be raised.
batch_dim_name: optional string, the axis name in the device_mesh
Contributor


Is this argument necessary? Can it not be inferred every time?

Member Author


I'd prefer not to infer it, since that could lead to some bizarre behavior; with the JAX backend, it might silently run without raising an error.
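
To make the point concrete, a small sketch (reusing the assumed `keras.distribution` names from the earlier example) of a mesh whose data-parallel axis is not named `batch`; passing `batch_dim_name` explicitly avoids having to guess which mesh axis the batch dimension maps to.

```python
# Hypothetical sketch: the data-parallel axis is called "data" here, so any
# inference keyed on a default axis name or position could silently pick the
# wrong mesh axis. Naming it explicitly keeps the behavior unambiguous.
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel, list_devices

devices = list_devices()
device_mesh = DeviceMesh(
    shape=(4, 2), axis_names=("model", "data"), devices=devices
)

layout_map = LayoutMap(device_mesh)
layout_map["dense.*kernel"] = ("model", None)

distribution = ModelParallel(
    device_mesh=device_mesh,
    layout_map=layout_map,
    batch_dim_name="data",  # explicit, not inferred
)
```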

@qlzh727 qlzh727 requested a review from fchollet August 27, 2023 00:00
Contributor

@fchollet fchollet left a comment


LGTM, thanks!

@fchollet fchollet merged commit a051b5c into keras-team:main Aug 27, 2023
@qlzh727 qlzh727 deleted the true_model_parallel branch August 29, 2023 17:03