Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix GemmaBackbone.get_layout_map + test #1669

Merged
merged 5 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions keras_nlp/src/models/gemma/gemma_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,18 @@ def get_layout_map(
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
model_dim,
data_dim,
None,
)
layout_map["decoder_block.*attention_output.*kernel"] = (
layout_map["decoder_block.*attention_output.kernel"] = (
model_dim,
None,
data_dim,
)
layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)

return layout_map
41 changes: 37 additions & 4 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def setUp(self):
self.init_kwargs = {
"vocabulary_size": 256128,
"num_layers": 2,
"num_query_heads": 4,
"num_key_value_heads": 4,
"num_query_heads": 8,
"num_key_value_heads": 8,
"hidden_dim": 128,
"intermediate_dim": 256,
"head_dim": 128,
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_all_presets(self):

def test_architecture_characteristics(self):
model = GemmaBackbone(**self.init_kwargs)
self.assertEqual(model.count_params(), 33407616)
self.assertEqual(model.count_params(), 33931904)
self.assertEqual(len(model.layers), 6)

def test_distribution(self):
Expand Down Expand Up @@ -132,7 +132,40 @@ def test_distribution(self):
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "ffw_linearl" in w.path:
if "ffw_linear" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)

def test_distribution_with_lora(self):
if keras.backend.backend() != "jax":
return
mattdangerw marked this conversation as resolved.
Show resolved Hide resolved
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
return
martin-gorner marked this conversation as resolved.
Show resolved Hide resolved
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
with distribution.scope():
model = GemmaBackbone(**self.init_kwargs)
model.enable_lora(rank=4)

for w in model.weights:
if "attention/query/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "attention/query/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
if "attention/value/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
Loading