GemmaBackbone.get_layout_map broken for gemma_2b_en #1613

Open
josharian opened this issue May 3, 2024 · 5 comments
Labels
Gemma (Gemma model specific issues), stat:awaiting keras-eng, type:Bug (Something isn't working)

Comments

@josharian

Describe the bug

When attempting to shard a gemma_2b_en model across two (consumer-grade) GPUs, I get:

ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('data': 1, 'model': 2), spec=PartitionSpec('model', 'data', None)), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 2048, 256))
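
A rough reproduction sketch (mine, not part of the original report, since the exact script isn't shown). It follows the documented get_layout_map usage on a two-device mesh, with the default 'batch'/'model' axis names rather than the 'data' axis that appears in the error message:

import keras
import keras_nlp

# Two consumer GPUs arranged as a (1, 2) mesh: no data parallelism, 2-way model parallelism.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 2), axis_names=("batch", "model"), devices=devices
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
with distribution.scope():
    # Fails while sharding the (1, 2048, 256) key/value kernels across the 'model' axis.
    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")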

The problem is the attention key/value kernels. gemma_2b_en decoder layer shapes:

decoder_block_0/pre_attention_norm/scale                    (2048,)         
decoder_block_0/attention/query/kernel                      (8, 2048, 256)  
decoder_block_0/attention/key/kernel                        (1, 2048, 256)  
decoder_block_0/attention/value/kernel                      (1, 2048, 256)  
decoder_block_0/attention/attention_output/kernel           (8, 256, 2048)  
decoder_block_0/pre_ffw_norm/scale                          (2048,)         
decoder_block_0/ffw_gating/kernel                           (2048, 16384)   
decoder_block_0/ffw_gating_2/kernel                         (2048, 16384)   
decoder_block_0/ffw_linear/kernel                           (16384, 2048)   

gemma_7b_en decoder layer shapes:

decoder_block_0/pre_attention_norm/scale                    (3072,)         
decoder_block_0/attention/query/kernel                      (16, 3072, 256) 
decoder_block_0/attention/key/kernel                        (16, 3072, 256) 
decoder_block_0/attention/value/kernel                      (16, 3072, 256) 
decoder_block_0/attention/attention_output/kernel           (16, 256, 3072) 
decoder_block_0/pre_ffw_norm/scale                          (3072,)         
decoder_block_0/ffw_gating/kernel                           (3072, 24576)   
decoder_block_0/ffw_gating_2/kernel                         (3072, 24576)   
decoder_block_0/ffw_linear/kernel                           (24576, 3072)   

Observe that the leading dimension of decoder_block.*attention.*(key|value).*kernel is divisible by 2/4/8/16 in gemma_7b_en but not in gemma_2b_en.
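
To see the divisibility constraint in isolation, here is a minimal JAX sketch (mine, not from the report) that fakes two CPU devices and applies the same PartitionSpec to an array with the gemma_2b_en key/value kernel shape:

# Force two fake CPU devices so this runs without GPUs.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# ('data': 1, 'model': 2) mesh, matching the error message above.
mesh = Mesh(np.array(jax.devices()).reshape(1, 2), axis_names=("data", "model"))

# The layout from #1491 puts 'model' on dim 0 of the q/k/v kernels.
spec = PartitionSpec("model", "data", None)
kv_kernel = jnp.zeros((1, 2048, 256))  # gemma_2b_en key/value kernel shape

# Raises the ValueError quoted above: dim 0 has size 1, which is not divisible
# by the 2-way 'model' axis.
jax.device_put(kv_kernel, NamedSharding(mesh, spec))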

Additional context

This was introduced in #1491. layout_map["decoder_block.*attention.*(query|key|value).*kernel"] was changed from (None, None, model_dim) to (model_dim, data_dim, None).

cc @qlzh727 @mattdangerw

There are other issues filed around LoRA training and the layout_map regular expressions. This is unrelated; it reproduces without LoRA enabled.

Would you like to help us fix it?

Sure, although I don't know what the preferred fix is. One obvious choice would be to make this not a static method any more, so we can pick optimal layouts for each model size.

@github-actions bot added the Gemma (Gemma model specific issues) label May 3, 2024
@qlzh727 (Member) commented May 3, 2024

Thanks for the report. @mattdangerw, since the k/v shapes differ from the q shape in the 2b model, we might want to change the sharding spec for them, e.g. we could make it (None, data, None) since the first dim is always 1.
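
A sketch of what that suggestion might look like inside get_layout_map (an illustration, not a tested fix): split the combined q/k/v rule so the key/value kernels, whose leading dim is 1 for gemma_2b_en, are not sharded on 'model':

import keras

device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 2), axis_names=("data", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras.distribution.LayoutMap(device_mesh)
# Query keeps the current head-dim sharding on 'model' (8 or 16 heads).
layout_map["decoder_block.*attention.*query.*kernel"] = ("model", "data", None)
# Key/value: leading dim is 1 for gemma_2b_en, so replicate it instead.
layout_map["decoder_block.*attention.*(key|value).*kernel"] = (None, "data", None)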

@josharian (Author)

(None, data, None)

I am new to this, so definitely don't listen to me too much...but for folks like me struggling to squish this onto consumer GPUs, it'd be nice to have some model parallelism everywhere.

@mattdangerw (Member) commented Jul 10, 2024

Thanks @josharian. Finally getting around to this -- sorry for the delay!

I think the issue here is that we have multi-head attention, multi-query attention, and now (with Gemma 2) grouped-query attention, all in the same Gemma architecture. To me that suggests we have the wrong signature: we need the model config to create the map. Instead of:

layout_map = keras_nlp.models.GemmaCausalLM.get_layout_map(mesh)
distribution = keras.distribution.ModelParallel(mesh, layout_map)
with distribution.scope():
    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("./path_to_preset")

We might need:

layout_map = keras_nlp.models.GemmaCausalLM.get_layout_map("./path_to_preset", mesh)
distribution = keras.distribution.ModelParallel(mesh, layout_map)
with distribution.scope():
    gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("./path_to_preset")

Or maybe there's a better API we could have. The order of operations gets kinda awkward here: you need to create the layout map before you create the model, but you need the config of the model before you create the layout map.

@mattdangerw (Member)

One alternative:

device_mesh = keras.distribution.DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'), devices=devices)
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "./path_to_preset",
    device_mesh=device_mesh,
)

And we actually enter into a ModelParallel device scope for you inside the from_preset call.

So either:

  1. Do your own ModelParallel scope, where you control everything.
  2. Just pass a mesh on construction and get an auto distributed model.

If you are creating your own model from scratch (via a direct constructor call), you have to do the former, since we don't know the correct layout map to create.
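
Concretely, a rough sketch of option 2 (hypothetical; load_backbone_config and build_layout_map are made-up placeholder helpers, not keras-nlp APIs):

import keras
import keras_nlp

def gemma_from_preset_distributed(preset, device_mesh, **kwargs):
    # Read the preset's config first (num_query_heads, num_key_value_heads, ...)
    # so the layout map can depend on the model size instead of being static.
    config = load_backbone_config(preset)  # hypothetical helper
    layout_map = build_layout_map(device_mesh, config)  # hypothetical helper
    distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
    # Enter the ModelParallel scope on the caller's behalf, as proposed above.
    with distribution.scope():
        return keras_nlp.models.GemmaCausalLM.from_preset(preset, **kwargs)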

@mattdangerw (Member)

@fchollet @martin-gorner any thoughts on this and the proposal in the last comment?
