Add model parallel distribution. #797

Merged · 2 commits · Aug 27, 2023
96 changes: 96 additions & 0 deletions keras_core/distribution/distribution_lib.py
@@ -289,6 +289,102 @@ def get_variable_layout(self, variable):
        return TensorLayout(variable_shard_spec, self.device_mesh)


class ModelParallel(Distribution):
"""Distribution that shard model weights.

Compare to DataParallel which replicates the weights across all the devices,
ModelParallel allows user to shard weights in addition to the input data.

To construct a ModelParallel distribution, user need to provide device mesh
and layout mapping.

1. `DeviceMesh`contains physcial device information, and the axis names in
the mesh will be used to map the weight and data layout.
2. `LayoutMap` contains the mapping for the variable path to its
corresponding `TensorLayout`.

    Example:
    ```python
    devices = list_devices() # Assume there are 8 devices.

    # Create a mesh with 2 devices for data parallelism and 4 devices for
    # weight parallelism.
    device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                             devices=devices)
    # Create a layout map that shards the dense layer and conv2d layer
    # weights on the last dimension. Based on the device_mesh, this means
    # the weights will be split across 4 devices. Any other weights that
    # don't match any key in the layout map will be fully replicated.
    layout_map = LayoutMap(device_mesh)
    layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])
    layout_map['.*dense.*bias'] = TensorLayout(['model'])
    layout_map['.*conv2d.*kernel'] = TensorLayout([None, None, None, 'model'])
    layout_map['.*conv2d.*bias'] = TensorLayout(['model'])

    distribution = ModelParallel(device_mesh=device_mesh,
                                 layout_map=layout_map,
                                 batch_dim_name='batch')
    # Set the global distribution, or use `with distribution.scope():`
    set_distribution(distribution)

    model = model_creation()
    model.compile()
    model.fit(data)
    ```

    Users can quickly update the device mesh shape to change the sharding
    factor of the weights. E.g.
    ```python
    # With only a shape change for the device mesh, the weights will be
    # sharded across 8 devices instead of 4, which further reduces the memory
    # footprint of the weights on each device.
    device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
                             devices=devices)
    ```

    To figure out a proper layout mapping rule for all the model weights, you
Contributor:

Do variables other than weights ever need to be sharded? e.g. optimizer variables, metrics. I assume optimizer variables will need to be sharded.

Member Author:

In the DTensor implementation, all the weights need to be either replicated or sharded. A normal tf.Variable doesn't work with DTensor variables within one tf.function.

For JAX, it might not be explicitly required, but it might be done (replicated by default) under the hood.

For optimizer variables, since they have names/paths similar to the weight names, they will probably get the same layout as the corresponding weights. The metric variables are replicated by default.

Contributor:

> In the DTensor implementation, all the weights need to be either replicated or sharded.

In this case, we should show the code example with model.variables, which is more complete. Or we could move get_variable_map() to the public API and recommend that. That's probably the best move, since it includes optimizer variables as well.

> For optimizer variables, since they have names/paths similar to the weight names

This isn't the case today -- try printing some optimizer variables for a given model to see some examples. We should figure out the recommended best practices for specifying layouts for optimizer variables.
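For illustration only (not part of this diff): a minimal sketch of printing both model and optimizer variable paths, assuming a Keras-3-style workflow where optimizer slot variables are created by `optimizer.build()` and every variable exposes a `.path` attribute. The toy model and layer names are hypothetical.

```python
from keras_core import layers, models

# Hypothetical toy model, used only to inspect variable paths.
inputs = layers.Input(shape=(4,))
outputs = layers.Dense(2, name="dense")(inputs)
model = models.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Materialize the optimizer slot variables before training.
model.optimizer.build(model.trainable_variables)

for v in model.weights:
    print("model weight:", v.path)
for v in model.optimizer.variables:
    print("optimizer variable:", v.path)
```

Whichever API is ultimately recommended (`model.variables` or a public `get_variable_map()`), the layout map keys would need to cover these paths.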

    can first list all the model weight paths, which will be used as the keys
    to map the weights to `TensorLayout`s.

    E.g.
    ```python
    model = create_model()
    for w in model.weights:
        print(w.path)
    ```
    """

    def __init__(self, device_mesh, layout_map, batch_dim_name=None):
        """Initialize the model parallel distribution.

        Args:
            device_mesh: `DeviceMesh` instance for the physical devices and
                their logical mapping.
            layout_map: `LayoutMap` instance which maps the variable paths to
                the corresponding `TensorLayout`s. The axis names of the
                `TensorLayout`s should match the axis names in the
                device_mesh, or an exception will be raised.
            batch_dim_name: optional string, the axis name in the device_mesh
Contributor:

Is this argument necessary? Can it not be inferred every time?

Member Author:

I'd prefer not to infer it, since that could lead to some bizarre behavior; in the JAX backend, it might silently run without raising an error.

                that will be used to distribute data. The first axis from the
                device_mesh will be used if the user didn't specify any.
        """
        super().__init__(device_mesh)
        self._layout_map = layout_map
        self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]

    def get_data_layout(self, data_shape):
        data_shard_spec = [None] * len(data_shape)
        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
        return TensorLayout(data_shard_spec, self.device_mesh)

    def get_variable_layout(self, variable):
        variable_layout = self._layout_map[variable.path]
        if variable_layout is not None:
            return variable_layout
        variable_shard_spec = [None] * len(variable.shape)
        return TensorLayout(variable_shard_spec, self.device_mesh)


class LayoutMap(collections.abc.MutableMapping):
"""A dict-like object that maps string to `TensorLayout` instances.

88 changes: 87 additions & 1 deletion keras_core/distribution/distribution_lib_test.py
@@ -193,6 +193,53 @@ def test_get_variable_layout(self):
        self.assertEqual(variable_layout.axes, [None])


class ModelParallelDistributionTest(testing.TestCase):
    def setUp(self):
        super().setUp()
        self.devices = [f"CPU:{i}" for i in range(8)]
        shape = (2, 4)
        axis_names = ["data", "model"]

        self.device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, self.devices
        )

    def test_distribute_weights(self):
        layout_map = distribution_lib.LayoutMap(self.device_mesh)
        layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"])
        layout_map[".*bias"] = distribution_lib.TensorLayout(["model"])

        distribution = distribution_lib.ModelParallel(
            self.device_mesh, layout_map, batch_dim_name="data"
        )
        kernel = backend.Variable(
            initializer=np.arange(8 * 4).reshape((8, 4)), name="kernel"
        )
        bias = backend.Variable(initializer=np.arange(4), name="bias")
        rng_seed = backend.Variable(initializer=[0, 1], name="seed")

        kernel_layout = distribution.get_variable_layout(kernel)
        self.assertIs(kernel_layout.device_mesh, self.device_mesh)
        self.assertEqual(kernel_layout.axes, [None, "model"])

        bias_layout = distribution.get_variable_layout(bias)
        self.assertIs(bias_layout.device_mesh, self.device_mesh)
        self.assertEqual(bias_layout.axes, ["model"])

        rng_seed_layout = distribution.get_variable_layout(rng_seed)
        self.assertIs(rng_seed_layout.device_mesh, self.device_mesh)
        self.assertEqual(rng_seed_layout.axes, [None])

    def test_distribute_data(self):
        layout_map = distribution_lib.LayoutMap(self.device_mesh)
        distribution = distribution_lib.ModelParallel(
            self.device_mesh, layout_map, batch_dim_name="data"
        )

        data = np.arange(16).reshape((4, 2, 2))
        data_layout = distribution.get_data_layout(data.shape)
        self.assertIs(data_layout.device_mesh, self.device_mesh)
        self.assertEqual(data_layout.axes, ["data", None, None])


class LayoutMapTest(testing.TestCase):
    def setUp(self):
        super().setUp()
@@ -345,7 +392,7 @@ def test_validation_for_device_mesh(self):
        ):
            backend_dlib.to_jax_layout(layout)

def test_e2e_model(self):
    def test_e2e_data_parallel_model(self):
        distribution = distribution_lib.DataParallel(
            devices=backend_dlib.list_devices()
        )
@@ -368,3 +415,42 @@ def test_e2e_model(self):
        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)

    def test_e2e_model_parallel_model(self):
        shape = (4, 2)
        axis_names = ["batch", "model"]
        device_mesh = distribution_lib.DeviceMesh(
            shape, axis_names, backend_dlib.list_devices()
        )

        layout_map = distribution_lib.LayoutMap(device_mesh)
        layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout(
            [None, "model"]
        )
        layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"])

        distribution = distribution_lib.ModelParallel(
            device_mesh, layout_map, batch_dim_name="batch"
        )
        with distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        for weight in model.weights:
            if "kernel" in weight.name:
                self.assertEqual(weight._value.sharding.spec, (None, "model"))
            elif "bias" in weight.name:
                self.assertEqual(weight._value.sharding.spec, ("model",))
            else:
                self.assertTrue(weight._value.sharding.is_fully_replicated)

        inputs = np.random.normal(size=(32, 28, 28, 1))
        labels = np.random.normal(size=(32, 10))

        with distribution.scope():
            model.compile(loss="mse")
            model.fit(inputs, labels)