Add model parallel distribution. #797
@@ -289,6 +289,100 @@ def get_variable_layout(self, variable):
        return TensorLayout(variable_shard_spec, self.device_mesh)


class ModelParallel(Distribution):
    """Distribution that shards model weights.

    Compared to DataParallel, which replicates the weights across all the
    devices, ModelParallel allows users to shard weights in addition to the
    input data.

    To construct a ModelParallel distribution, the user needs to provide a
    device mesh and a layout map:

    1. `DeviceMesh` contains the physical device information, and the axis
       names in the mesh will be used to map the weight and data layouts.
    2. `LayoutMap` contains the mapping from a variable path to its
       corresponding `TensorLayout`.
    Example:
    ```
    devices = list_devices()  # Assume there are 8 devices.

    # Create a mesh with 2 devices for data parallelism and 4 devices for
    # weight parallelism.
    device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                             devices=devices)
    # Create a layout map that shards the dense layer and conv2d layer
    # weights on the last dimension. Based on the device_mesh, this means
    # the weights will be split across 4 devices. Any other weights that
    # don't match any key in the layout map will be fully replicated.
    layout_map = LayoutMap(device_mesh)
    layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])
    layout_map['.*dense.*bias'] = TensorLayout(['model'])
    layout_map['.*conv2d.*kernel'] = TensorLayout([None, None, None, 'model'])
    layout_map['.*conv2d.*bias'] = TensorLayout(['model'])

    distribution = ModelParallel(device_mesh=device_mesh,
                                 layout_map=layout_map,
                                 batch_dim_name='batch')
    with distribution.scope():
        model = model_creation()
        model.compile()
        model.fit(data)
    ```

Review comment: If the primary usage is via the global setter, we should show
that in the code example.

Reply: Done.
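As a rough illustration of that suggestion, here is a sketch of the example
rewritten around a global setter; `set_distribution()` is an assumed name for
that setter and may differ from the actual API in this PR:

```python
# Sketch only: `set_distribution` is a hypothetical module-level setter.
distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
set_distribution(distribution)  # Instead of `with distribution.scope():`.

model = model_creation()
model.compile()
model.fit(data)
```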

    Users can quickly update the device mesh shape to change the sharding
    factor of the weights. E.g.
    ```
    # With only the shape change for the device mesh, the weights will be
    # sharded across 8 devices instead of 4, which further reduces the
    # memory footprint of the weights on each device.
    device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
                             devices=devices)
    ```

    To figure out a proper layout mapping rule for all the model weights, you
    can first list out all the model weight paths, which will be used as the
    keys to map the weights to `TensorLayout`.

    E.g.
    ```
    model = create_model()
    for w in model.weights:
        print(w.path)
    ```
    """

Review comment: Do variables other than weights ever need to be sharded?
E.g. optimizer variables, metrics. I assume optimizer variables will need to
be sharded.

Reply: In the DTensor implementation, all the weights need to be either
replicated or sharded. A normal tf.Variable doesn't work with a DTensor
variable within one tf.function. For JAX, it might not be explicitly
required, but it might have been done (replicated by default) under the hood.
For optimizer variables, since they have a similar name/path as the weight
names, they will probably get the same layout as the variable. The metric
variables are replicated by default.

Review comment: In this case, we should show the code example with

Review comment: This isn't the case today -- try printing some optimizer
variables for a given model for some examples. We should figure out the
recommended best practices for specifying layouts for optimizer variables.

    def __init__(self, device_mesh, layout_map, batch_dim_name=None):
        """Initialize the model parallel distribution.

        Args:
            device_mesh: `DeviceMesh` instance for the physical devices and
                their logical mapping.
            layout_map: `LayoutMap` instance which maps the variable paths to
                the corresponding `TensorLayout`. The axis names of the
                `TensorLayout`s should match the axis names in the
                device_mesh, or an exception will be raised.
            batch_dim_name: optional string, the axis name in the device_mesh
                that will be used to distribute data. The first axis of the
                device_mesh will be used if the user didn't specify one.
        """
        super().__init__(device_mesh)
        self._layout_map = layout_map
        self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]

Review comment: Is this argument necessary? Can it not be inferred every
time?

Reply: I hope to not infer it, which could lead to some bizarre behavior, and
in the JAX backend it might silently run without raising an error.
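A small sketch of the fallback described above: with the device mesh from the
class docstring, omitting `batch_dim_name` defaults to the first mesh axis
rather than inferring anything (the private attribute is read here purely for
illustration):

```python
# Default: the first axis name of the mesh ('batch' in the docstring
# example) is used for data distribution.
distribution = ModelParallel(device_mesh=device_mesh, layout_map=layout_map)
assert distribution._batch_dim_name == 'batch'

# Passing the axis name explicitly is equivalent here, but necessary when
# the batch axis is not the first one in the mesh.
distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
```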

    def get_data_layout(self, data_shape):
        data_shard_spec = [None] * len(data_shape)
        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim.
        return TensorLayout(data_shard_spec, self.device_mesh)

    def get_variable_layout(self, variable):
        variable_layout = self._layout_map[variable.path]
        if variable_layout is not None:
            return variable_layout
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)
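To make the two lookups above concrete, a short sketch using the mesh and
layout map from the class docstring; `dense_kernel` and `unmatched_variable`
are stand-ins for a variable whose path matches the `'.*dense.*kernel'` rule
and one that matches no rule at all:

```python
# Data is sharded on the batch (first) dimension only.
data_layout = distribution.get_data_layout((32, 28, 28, 3))
# -> TensorLayout(['batch', None, None, None], device_mesh)

# A matching variable path gets the layout stored in the LayoutMap ...
kernel_layout = distribution.get_variable_layout(dense_kernel)
# -> TensorLayout([None, 'model'], device_mesh)

# ... and any variable with no matching rule falls back to a fully
# replicated layout (all dimensions unsharded).
other_layout = distribution.get_variable_layout(unmatched_variable)
# -> TensorLayout([None, None, ...], device_mesh)
```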


class LayoutMap(collections.abc.MutableMapping):
    """A dict-like object that maps string to `TensorLayout` instances.

Review comment: Add: `python`

Reply: Done.
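To round out the `LayoutMap` description above, a hedged sketch of the
dict-like behavior it implies; the pattern-style lookup is inferred from how
`ModelParallel.get_variable_layout` consumes the map (it expects `None` for
unmatched paths), not from code shown in this diff:

```python
layout_map = LayoutMap(device_mesh)
layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])

# Keys behave as patterns: a concrete variable path that matches a key
# resolves to the stored layout, and an unmatched path resolves to None.
print(layout_map['sequential/dense/kernel'])   # TensorLayout([None, 'model'], ...)
print(layout_map['sequential/conv2d/kernel'])  # None -- no matching rule.
```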