GPTNeoX Flax support #25334
Conversation
Looking super clean already @HeegyuKim! Mainly just some very minor comments from me - the overall design is great. Once these are addressed and the tests pass we can get it ready for merge 🚀
I'm suffering from a test issue. Can you help me? @sanchit-gandhi
I don't think this is a problem with my model implementation, and I wonder why the PyTorch cross test fails. This PR failed the two tests below.
But the two Flax tests in tests/models/gpt_neox/test_modeling_flax_gptneox.py are fine. The test code, which was copied from #24002, overrides both test_equivalence_pt_to_flax and test_equivalence_flax_to_pt with this comment,
and they use the assert code below
instead. These overrides are equal to
Hey @HeegyuKim - could you confirm that you get the same logits out from the Flax model when you generate as with the PyTorch model? i.e. that the generation scores are the same in both cases. If this is indeed the case, then we can know for certain that the Flax implementation is correct, and that we need to override the PT-FX cross tests. Otherwise, there's a divergence that we need to fix! We can check this with the full GPT NeoX model to ensure we have the right logits here.
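For illustration, a minimal sketch of such a logits check might look like the following (the small checkpoint, the tolerance, and the `FlaxGPTNeoXForCausalLM` class added in this PR are assumptions here, not part of the original comment):

```python
import numpy as np
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM, FlaxGPTNeoXForCausalLM

# small GPT-NeoX checkpoint chosen only to keep the check fast (illustrative)
model_id = "EleutherAI/pythia-70m"

tokenizer = AutoTokenizer.from_pretrained(model_id)
pt_model = GPTNeoXForCausalLM.from_pretrained(model_id)
fx_model = FlaxGPTNeoXForCausalLM.from_pretrained(model_id, from_pt=True)

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")

with torch.no_grad():
    pt_logits = pt_model(torch.tensor(inputs["input_ids"])).logits.numpy()
fx_logits = np.asarray(fx_model(**inputs).logits)

# on CPU / float32 the two sets of logits should agree to roughly 1e-5
print("max abs diff:", np.max(np.abs(pt_logits - fx_logits)))
```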
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey @HeegyuKim! I thought a little bit more about the PT-FX cross tests in the [WIP] Flax LLaMA port with @vvvm23, and suggested that probably the reason for the failing tests is the random attention mask: #24587 (comment) If we switch to using a causal attention mask, we are able to get PT-FX equivalence for Flax LLaMA without overriding the tests. Since Flax LLaMA is heavily based off Flax GPT-Neo, I'm fairly certain we'll observe similar behaviour for Flax GPT-NeoX. Would you like to try running the tests using a causal attention mask? E.g. as per #24587 (comment)
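For illustration, swapping the random mask for a causal (lower-triangular) one in the model tester would look roughly like this (the sizes and attribute names are placeholders):

```python
import numpy as np

batch_size, seq_length = 2, 7  # placeholder sizes used by the model tester

# lower-triangular (causal-style) mask instead of a random one; with a random mask
# the PT and Flax models were found to diverge in the cross tests
attention_mask = np.tril(np.ones((batch_size, seq_length), dtype="i4"))
```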
Hi! Thanks @HeegyuKim for the PR. I am wondering if there is any update on this? It would be really cool if we could use GPTNeoXForCausalLM in flax!
Hello @liutianlin0121 I'm trying to solve the problem whenever I have time. However, even if causal masking is applied, the error in the model output is still larger than 1e-5. The current error is around 0.02-0.03. I'm going to try again this weekend. Even though there are errors, the model works better than expected. I trained several models with this code.
I want to contribute to huggingface but it's not as easy as I thought.
It looks like you're making solid progress @HeegyuKim! Nice work! The most recent CI run is reporting that the difference in the Flax and PyTorch model outputs is > 1 (see CI Output). This suggests to me that there is a divergence between the Flax and PyTorch models. Generally, for any model less than 1B params, we should be able to get equivalence to within 1e-5 between Flax and PyTorch. It's quite likely that you won't get this equivalence running the matmuls in bfloat16 on TPU, but you should be able to when running the matmuls in float32; see #15754 and jax-ml/jax#10413 (comment) for details.
Here's a script that I used previously for checking PT / Flax equivalence for BLOOM: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb You can ignore the bits about JIT'ing the forward pass for the time being. You can also uncomment the check to run it on CPU to force the highest precision, or use the decorator as provided
If we don't get 1e-5 precision, it's usually an indicator that we have a divergence in our model. Here, going through layer-by-layer and checking the hidden-states might be required to pinpoint it. Once you have this equivalence, it's almost guaranteed that the CI will report a difference of less than 1e-5, since it runs on CPU.
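For example, a rough sketch of such a layer-by-layer comparison, forcing full-precision matmuls (the checkpoint and the `FlaxGPTNeoXForCausalLM` class from this PR are again assumptions):

```python
import jax
import numpy as np
import torch
from transformers import GPTNeoXForCausalLM, FlaxGPTNeoXForCausalLM

# force float32 matmuls so a real divergence isn't masked by low-precision accumulation
jax.config.update("jax_default_matmul_precision", "float32")

model_id = "EleutherAI/pythia-70m"  # illustrative small GPT-NeoX checkpoint
pt_model = GPTNeoXForCausalLM.from_pretrained(model_id)
fx_model = FlaxGPTNeoXForCausalLM.from_pretrained(model_id, from_pt=True)

input_ids = np.arange(16)[None, :] % pt_model.config.vocab_size

with torch.no_grad():
    pt_out = pt_model(torch.tensor(input_ids), output_hidden_states=True)
fx_out = fx_model(input_ids, output_hidden_states=True)

# compare hidden states layer by layer to pinpoint where a divergence first appears
for layer, (pt_h, fx_h) in enumerate(zip(pt_out.hidden_states, fx_out.hidden_states)):
    print(f"layer {layer}: max abs diff {np.max(np.abs(pt_h.numpy() - np.asarray(fx_h))):.2e}")
```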
Let me know if you have any questions / queries about finishing this PR. You've done a great job and I'd be more than happy to assist you in seeing this to completion!
Ohhhhh I finally got past the equivalence issue! 🎉🎉
But there are CI failures...
Well done @HeegyuKim, that's excellent news! Regarding the two failing tests:
```python
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", from_pt=True)
```

And then push the converted Flax weights to the Hub:

```python
model.push_to_hub("EleutherAI/gpt-neox-20b", create_pr=True)
```
Looks in great shape @HeegyuKim! And nice job on getting equivalence with PyTorch! Left a few suggestions below, mainly just small re-factoring to get it ready for merge. Feel free to ping me as soon as you're ready for a final look - think it should be pretty fast to get it merged from here
```python
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
```

```python
self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
```
We've only implemented one of the three possible rotary embedding types (`rope_scaling=None`). There are two more RoPE types in the PyTorch modelling code:

```python
def _init_rope(self):
```
I don't think it would be too much work to add these so that we have equivalence with the PyTorch modelling code? WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the RoPE scaling code and test code, but in the test CI the frozen Flax module raises SetAttributeFrozenModuleError.
For the cached RoPE embedding, I think I should use a Flax variable. I'm trying to implement it, and it would be nice if you had an appropriate reference or suggestion. @sanchit-gandhi
Ah, we can only set attributes in Flax in the `setup` method. After that, the module gets frozen, meaning we can't add new attributes or update existing ones. So probably the cached embedding isn't going to work. What we can instead do is always initialise the embeddings to max length (`config.max_position_embeddings`), and then slice the first `N` embeddings as required each time. This way, we don't ever need to re-compute or update the embeddings, since we always have the max embedding length we require stored in the `setup`.
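A minimal sketch of that pattern, with illustrative names rather than the ones used in this PR:

```python
import flax.linen as nn
import jax.numpy as jnp


class RotaryEmbeddingSketch(nn.Module):
    """Illustrative only: pre-compute cos/sin up to the max length in setup, then slice per call."""

    dim: int
    max_position_embeddings: int
    base: float = 10000.0

    def setup(self):
        # attributes can only be created here; afterwards the module is frozen
        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        t = jnp.arange(self.max_position_embeddings, dtype=jnp.float32)
        freqs = jnp.outer(t, inv_freq)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        self.cos_cached = jnp.cos(emb)
        self.sin_cached = jnp.sin(emb)

    def __call__(self, seq_len):
        # never recompute or mutate: just slice the first seq_len positions
        return self.cos_cached[:seq_len, :], self.sin_cached[:seq_len, :]
```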
```python
input_mask = None
if self.use_input_mask:
    input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
```
Note to reviewer: we use a causal attention mask in the Flax generation tests. This is required to get sensible outputs from the Flax model (which is typical behaviour).
I think we're almost at the end of our work, but there are small issues. Suddenly the wav2vec2 test fails??
Flax weights
Hey @HeegyuKim! Nice job on iterating here! Answering your questions in-line below:
Note that it's important you force push (
Finally, documentation is left. How can I write the documentation for it? @sanchit-gandhi
You can do so with
I may have passed the necessary CI tests! @sanchit-gandhi
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Very nice @HeegyuKim! Sorry about the delay with getting you another review. It's mainly small nits from me. Let's put in the request for a core maintainer to take a final look at this and get it merged!
```diff
@@ -63,12 +63,17 @@ def quick_gelu(x):
     return x * jax.nn.sigmoid(1.702 * x)
 
 
+def gelu_fast(x):
```
Thanks for adding this!
```python
        return self.cos_cached, self.sin_cached

    def _compute_cos_sin(self, seq_len):
        t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
```
Note to reviewer: single-letter variables chosen to maintain equivalence with the PyTorch modelling code
```python
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
```
```python
    return jnp.concatenate((-second_half, first_half), axis=-1)


class FlaxGPTNeoXRotaryEmbedding(nn.Module):
```
The embedding classes are very nice! Ported to JAX while closely following the logic from PyTorch.
```python
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
    def __call__(
```
Can we not copy this entire method from GPT Neo as well? The code looks to be one-to-one the same now; it's just a comment which is different: "# if past_key_values are passed...".
If we do so, then we can actually just copy the entire class from GPT Neo, which would make the copied-from statements much simpler.
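For reference, a single class-level copied-from statement would then look something like this (the exact source path and replacement pattern are illustrative):

```python
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->GPTNeoX, GPT_NEO->GPT_NEOX
class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
    ...
```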
```python
        hidden_states = outputs[0]

        lm_logits = self.embed_out(hidden_states)
```
What about possibly tied word embeddings?
transformers/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py, lines 635 to 637 in 2c658b5:

```python
if self.config.tie_word_embeddings:
    shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
    lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
```
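Adapted to this model, the same branch might look roughly like the following (a sketch; the `gpt_neox` / `embed_in` names mirror the PyTorch model and are assumptions about this PR's Flax module layout):

```python
if self.config.tie_word_embeddings:
    # reuse the transposed input embedding matrix as the output projection kernel
    shared_kernel = self.gpt_neox.variables["params"]["embed_in"]["embedding"].T
    lm_logits = self.embed_out.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
    lm_logits = self.embed_out(hidden_states)
```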
Great work here! @HeegyuKim Thanks for adding the flax support 🤗
```diff
@@ -0,0 +1,783 @@
+# coding=utf-8
+# Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
```
Suggested change:

```diff
-# Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
+# Copyright 2023 The HuggingFace Inc. team.
```
```python
cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
```
I think we got rid of the extra dimensions in the PyTorch version in #26162.
```python
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
```
thanks for the pointers! 😉
```python
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def _split_heads(self, hidden_states):
```
It's a tiny bit counter-intuitive that `_split_heads` does not split! It's a nit, but here's what we have in BLOOM for the same function:
transformers/src/transformers/models/persimmon/modeling_persimmon.py, lines 237 to 252 in 587b8e6:

```python
# Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
    storage as `fused_qkv`

    Args:
        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

    Returns:
        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
        value: [batch_size, seq_length, num_heads, head_dim]
    """
    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
```
(Other JAX implementations have the same, so feel free to choose whatever you prefer.)
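In JAX, a split that actually mirrors the BLOOM-style helper could look something like this (a sketch only; the fused QKV layout and the attribute names are assumptions, not the PR's actual code):

```python
def _split_heads(self, fused_qkv):
    # fused_qkv: [batch, seq_len, num_heads * 3 * head_dim]
    batch, seq_len, _ = fused_qkv.shape
    fused_qkv = fused_qkv.reshape(batch, seq_len, self.num_attention_heads, 3, self.head_size)
    # each return value: [batch, seq_len, num_heads, head_dim]
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
```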
```python
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
```
this is only used once so a bit useless, let's remove it
```python
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=jnp.promote_types(self.dtype, jnp.float32),
```
why do we have to use this? (Just FMI 🤗 )
```python
        return (hidden_states,) + attn_outputs[1:]


class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
```
We'll have a `# Ignore copy` soon 😉
The PR is almost finished! Would you like to make the last remaining changes @HeegyuKim such that we can get this one merged? Let us know if you have any questions, more than happy to help here
Thank you for your comment! I'll check it this weekend
This reverts commit a670443.
Hello @sanchit-gandhi, I rebased this PR onto the main branch and pushed again. There are two CI failures. The first is a documentation issue.
This problem can be fixed once this GPT-NeoX model PR is merged. Alternatively, we can add from_pt=True to the example. As for the second issue, I don't know the cause. I would appreciate it if you could tell me the cause and a solution to this problem.
I ran
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes #22950:
@sanchit-gandhi