
Add get_layout_map() for all backbones #1689

Open
mattdangerw opened this issue Jul 10, 2024 · 1 comment

@mattdangerw (Member):

We want model parallelism to be easy to use across the library.

We should add a `get_layout_map()` implementation to all backbones. This should be mostly copy-paste from the Gemma version below, since all of our transformers share essentially the same weight structure.

@staticmethod
def get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
):
    """Get a `keras.distribution.LayoutMap` for model parallel distribution.

    The returned `LayoutMap` contains the sharding spec for the Gemma
    backbone weights, so that you can use it to distribute weights across
    the accelerators.

    Example:
    ```
    # Feel free to change the mesh shape to balance data and model parallelism.
    mesh = keras.distribution.DeviceMesh(
        shape=(1, 8),
        axis_names=("batch", "model"),
        devices=keras.distribution.list_devices(),
    )
    layout_map = GemmaBackbone.get_layout_map(
        mesh, model_parallel_dim_name="model"
    )
    distribution = keras.distribution.ModelParallel(
        mesh, layout_map, batch_dim_name="batch"
    )
    with distribution.scope():
        gemma_model = keras_nlp.models.GemmaCausalLM.from_preset()
    ```

    Args:
        device_mesh: The `keras.distribution.DeviceMesh` instance for
            distribution.
        model_parallel_dim_name: The axis name of the device mesh that
            the weights should be partitioned on.
        data_parallel_dim_name: The axis name of the device mesh that
            the data should be partitioned on.

    Returns:
        `keras.distribution.LayoutMap` that contains the sharding spec
        for all the model weights.
    """
    # The weight paths and shapes of the Gemma backbone look like this
    # (for the 2B model):
    # token_embedding/embeddings, (256128, 2048), 524550144
    # repeat block for decoder
    # ...
    # decoder_block_17/pre_attention_norm/scale, (2048,), 2048
    # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304
    # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304
    # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304
    # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304
    # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048
    # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432
    # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432
    # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432
    if not isinstance(device_mesh, keras.distribution.DeviceMesh):
        raise ValueError(
            "Invalid device_mesh type. Expected "
            f"`keras.distribution.DeviceMesh`, got {type(device_mesh)}"
        )
    if model_parallel_dim_name not in device_mesh.axis_names:
        raise ValueError(
            f"{model_parallel_dim_name} is not found in the "
            f"device_mesh.axis_names. {device_mesh.axis_names=}"
        )
    if data_parallel_dim_name not in device_mesh.axis_names:
        raise ValueError(
            f"{data_parallel_dim_name} is not found in the "
            f"device_mesh.axis_names. {device_mesh.axis_names=}"
        )
    # Note that it is possible to further configure the mesh to be 3D, e.g.
    # (data, seq, model). We leave it as 2D for now for simplicity.
    data_dim = data_parallel_dim_name
    model_dim = model_parallel_dim_name
    # The sharding config is based on the Gemma team's training config.
    # See https://arxiv.org/abs/2403.08295
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
    layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
        model_dim,
        data_dim,
        None,
    )
    layout_map["decoder_block.*attention_output.kernel"] = (
        model_dim,
        None,
        data_dim,
    )
    layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
    layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
    layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
    return layout_map
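
For context, a minimal end-to-end sketch of how the returned layout map gets used, expanded from the docstring example above. The preset name `"gemma_2b_en"` and the positional `ModelParallel(mesh, layout_map, ...)` signature reflect current Keras/KerasNLP usage and may differ across versions:

```python
import keras
import keras_nlp

# Build a 1 x N mesh: put every local accelerator on the "model" axis.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# Ask the backbone for its sharding spec, then activate it globally.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(mesh)
distribution = keras.distribution.ModelParallel(
    mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)

# Any model constructed from here on is built with sharded weights.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```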

mattdangerw added the type:feature label on Jul 10, 2024
github-actions bot added the Gemma label on Jul 10, 2024
@mattdangerw (Member, Author):

We should also keep the docstring for the method on the Backbone base class, and factor out all the error checking somehow, so that the per-model code here could be really minimal.
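
One rough sketch of what that split could look like. This is not a settled design; the helper name `_validate_layout_map_args`, the class `SomeBackbone`, and the weight-path regexes are placeholders:

```python
import keras


class Backbone(keras.Model):
    """Base class; the full `get_layout_map()` docstring would live here."""

    @staticmethod
    def _validate_layout_map_args(
        device_mesh, model_parallel_dim_name, data_parallel_dim_name
    ):
        # Shared error checking, so each backbone only declares its regexes.
        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
            raise ValueError(
                "Invalid device_mesh type. Expected "
                f"`keras.distribution.DeviceMesh`, got {type(device_mesh)}"
            )
        for dim_name in (model_parallel_dim_name, data_parallel_dim_name):
            if dim_name not in device_mesh.axis_names:
                raise ValueError(
                    f"{dim_name} is not found in device_mesh.axis_names. "
                    f"{device_mesh.axis_names=}"
                )


class SomeBackbone(Backbone):
    @staticmethod
    def get_layout_map(
        device_mesh,
        model_parallel_dim_name="model",
        data_parallel_dim_name="batch",
    ):
        # Per-model code reduces to the weight-path -> layout mapping.
        # The regexes below are illustrative; each backbone would use its
        # own variable paths.
        Backbone._validate_layout_map_args(
            device_mesh, model_parallel_dim_name, data_parallel_dim_name
        )
        model_dim = model_parallel_dim_name
        data_dim = data_parallel_dim_name
        layout_map = keras.distribution.LayoutMap(device_mesh)
        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
            model_dim,
            data_dim,
            None,
        )
        # ... remaining per-model entries ...
        return layout_map
```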
