Keras 3.2.1 breaks LoRA with ModelParallel #19496
Comments
I didn't spot the issue on the Keras side. Could it possibly be a bug in JAX?
I tried this script on Kaggle and got a different error when executing it more than two times:

# Create a device mesh with shape (1, 8) to partition weights across all 8 TPU cores.
devices = keras.distribution.list_devices()  # 8 TPUs
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)
# Create a LayoutMap to partition relevant weights.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
# Make ModelParallel loading using the LayoutMap the default.
keras.distribution.set_distribution(distribution)
# Initialize GemmaBackbone.
gemma_backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_1.1_instruct_2b_en")

The error:

Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Kaggle notebook...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[26], line 1
----> 1 gemma_backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_1.1_instruct_2b_en")
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/backbone.py:200, in Backbone.from_preset(cls, preset, load_weights, **kwargs)
194 if not issubclass(preset_cls, cls):
195 raise ValueError(
196 f"Preset has type `{preset_cls.__name__}` which is not a "
197 f"a subclass of calling class `{cls.__name__}`. Call "
198 f"`from_preset` directly on `{preset_cls.__name__}` instead."
199 )
--> 200 return load_from_preset(
201 preset,
202 load_weights=load_weights,
203 config_overrides=kwargs,
204 )
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/utils/preset_utils.py:376, in load_from_preset(preset, load_weights, config_file, config_overrides)
374 config = json.load(config_file)
375 config["config"] = {**config["config"], **config_overrides}
--> 376 layer = keras.saving.deserialize_keras_object(config)
378 # Load any assets for our tokenizers.
379 tokenizer = get_tokenizer(layer)
File /usr/local/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:711, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
709 with custom_obj_scope, safe_mode_scope:
710 try:
--> 711 instance = cls.from_config(inner_config)
712 except TypeError as e:
713 raise TypeError(
714 f"{cls} could not be deserialized properly. Please"
715 " ensure that components that are Python object"
(...)
719 f"\n\nconfig={config}.\n\nException encountered: {e}"
720 )
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/backbone.py:135, in Backbone.from_config(cls, config)
131 @classmethod
132 def from_config(cls, config):
133 # The default `from_config()` for functional models will return a
134 # vanilla `keras.Model`. We override it to get a subclass instance back.
--> 135 return cls(**config)
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_backbone.py:149, in GemmaBackbone.__init__(self, vocabulary_size, num_layers, num_query_heads, num_key_value_heads, hidden_dim, intermediate_dim, head_dim, layer_norm_epsilon, dropout, dtype, **kwargs)
147 x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype)
148 for transformer_layer in self.transformer_layers:
--> 149 x = transformer_layer(x, padding_mask=padding_mask_input)
150 sequence_output = self.layer_norm(x)
151 super().__init__(
152 inputs={
153 "token_ids": token_id_input,
(...)
158 **kwargs,
159 )
File /usr/local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_decoder_block.py:96, in GemmaDecoderBlock.build(self, input_shape)
94 def build(self, input_shape):
95 self.pre_attention_norm.build(input_shape)
---> 96 self.attention.build(input_shape)
98 shape = input_shape
99 self.pre_ffw_norm.build(shape)
File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_attention.py:64, in CachedGemmaAttention.build(self, inputs_shape)
55 self.query_dense.build(inputs_shape)
57 self.key_dense = keras.layers.EinsumDense(
58 "bsd,kdh->bskh",
59 output_shape=(None, self.num_key_value_heads, self.head_dim),
(...)
62 name="key",
63 )
---> 64 self.key_dense.build(inputs_shape)
66 self.value_dense = keras.layers.EinsumDense(
67 "bsd,kdh->bskh",
68 output_shape=(None, self.num_key_value_heads, self.head_dim),
(...)
71 name="value",
72 )
73 self.value_dense.build(inputs_shape)
File /usr/local/lib/python3.10/site-packages/jax/_src/api.py:2519, in device_put(x, device, src)
2514 if ((device is None or
2515 isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
2516 (src is None or
2517 isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
2518 for leaf in tree_leaves(x):
-> 2519 _check_sharding(leaf, s=device)
2520 return tree_map(
2521 lambda y: dispatch.device_put_p.bind(
2522 y, device=device, src=_infer_src_sharding(src, y)), x)
2524 x_flat, treedef = tree_flatten(x)
File /usr/local/lib/python3.10/site-packages/jax/_src/api.py:2482, in _check_sharding(x, s)
2480 aval = shaped_abstractify(x)
2481 if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
-> 2482 pjit.pjit_check_aval_sharding(
2483 (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
2484 s.shard_shape(aval.shape)
File /usr/local/lib/python3.10/site-packages/jax/_src/pjit.py:1034, in pjit_check_aval_sharding(shardings, flat_avals, names, what_aval, allow_uneven_sharding)
1032 for i, size in enumerate(num_ways_dim_sharded):
1033 if not allow_uneven_sharding and shape[i] % size != 0:
-> 1034 raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
1035 f"of {s}, which implies that "
1036 f"the global size of its dimension {i} should be "
1037 f"divisible by {size}, but it is equal to {shape[i]} "
1038 f"(full shape: {shape})")
ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('batch': 1, 'model': 8), spec=PartitionSpec('model', 'batch', None)), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 1 (full shape: (1, 2048, 256))

This should indicate that there is an issue with the model initialization of
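For context, here is a minimal, hedged sketch of the JAX check that produces this ValueError (assuming an 8-device host; the shape and PartitionSpec are copied from the error above): jax.device_put rejects a sharding that splits a dimension across more devices than that dimension's size allows.

# Minimal repro of the divisibility check (assumes 8 accelerator devices).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(1, 8), axis_names=("batch", "model"))
# Same shape as the LoRA kernel in the traceback: dim 0 has size 1,
# but the spec shards dim 0 across the 8-way "model" axis, and 1 % 8 != 0.
lora_a = np.zeros((1, 2048, 256), dtype=np.float32)
sharding = NamedSharding(mesh, PartitionSpec("model", "batch", None))
jax.device_put(lora_a, sharding)  # raises: "... should be divisible by 8 ..."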
Has anyone found a solution or any workaround?
I'm seeing the same issue on Kaggle.
As a workaround, change the Keras version back to 3.1.1.
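A possible sketch of that workaround in a notebook environment (assumes you restart the runtime afterwards so the downgraded Keras is picked up):

# Pin Keras back to 3.1.1 from inside a notebook cell, then restart the runtime.
import subprocess, sys
subprocess.run([sys.executable, "-m", "pip", "install", "keras==3.1.1"], check=True)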
Started looking... I think the issue is a bug in the guides (and … written by bballe@).

We can restore the original behavior of the guide by updating our layout map paths to be a little stricter (so they do not select the LoRA kernel variables):

# Regex to match against the query, key and value matrices in attention layers.
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)

It also might be OK to shard the LoRA kernel "A" variables on the same rank as our actual kernels, but I can't imagine that would impact the runtime very much given the LoRA variable sizes. I'll test this out tomorrow.
Tested, the fix works well. Thank you!
Possibly related: keras-team/keras-nlp#1613
Hi @martin-gorner, can we mark this as resolved now? Also, it seems to be KerasNLP-specific, as mentioned in this comment.
Yes, this is resolved.
* fix to default GemmaBackbone.get_layout_map() to use fixed regexes as per keras-team/keras#19496 (comment)
* Also fixing forgotten ffw_gating_2 in GemmaBackbone.get_layout_map. The sharding spec ("batch", "model") is the one that provides the best training performance; ("batch", "model") and (None, None) are slower (the first one by 40%, the second by 2%). Fixing test too, including typo ffw_linearl => ffw_linear.
* changed test_architecture_characteristics test to follow the 4->8 heads change necessary for the test to work on TPUs. Also fixed formatting.
* Update gemma_backbone_test.py: better test messages.

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
You can test with the Gemma chat demo: bit.ly/gemma-pirate-demo

With Keras 3.1.1, the line

gemma_lm.backbone.enable_lora(rank=8)

executes successfully. With Keras 3.2.1, the same line errors out with:
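For reference, a hedged reconstruction of the failing step (the exact code behind the bit.ly demo link is an assumption; the preset name is taken from the traceback above): with the guide's ModelParallel distribution active, enable_lora() adds small lora_kernel_a/lora_kernel_b variables, and under Keras 3.2.1 the guide's broad layout-map regexes also matched those LoRA kernels, triggering the sharding error shown earlier.

import keras
import keras_nlp

# Assumes the ModelParallel distribution from the guide has already been set.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
gemma_lm.backbone.enable_lora(rank=8)  # OK on Keras 3.1.1; fails on 3.2.1 with the old layout map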