Abnormal behavior of OPT except OPT-350m #17653
Related issue #17545 |
Hmm, that's a very good test - it indeed seems like there is a bug with the model weights conversion... @stephenroller can you by any chance run those models with metaseq to verify that |
Also @hemingkx, I can confirm your results - just sent a mail to the authors since we think the problem lies in the model conversion |
Thanks for your early reply! Hope this can be resolved soon 😊. |
This is in my queue, just have a long queue. |
@stephenroller Big Science was hoping to compare the BLOOM model to OPT. Do you have any idea if it’s reasonable to expect that this is resolved by the end of the month? I would be very appreciative if you could keep me apprised of the progress. |
@stephenroller do you use Tensor Parallelism for OPT models >= 1.3b? (I can't find the information in the paper) |
We used tensor parallelism for all models EXCEPT 350 :P |
@patrickvonplaten then it might explain these numbers on all models except 350m ? |
Also, do you remember which layers you TP-ranked? |
EleutherAI has found that merging TP parallel models can be extremely non-intuitive, and it took some substantial work from @zphang to figure out how to do it while maintaining performance. You can find code for merging the TP parallel shards for our 20B model here and I think we have some more general code as well but I would have to hunt it down to release it. |
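As a rough sketch of what such a merge involves (this is an assumption about a Megatron-style layout, not the actual metaseq or EleutherAI merging code; `column_parallel_keys` and `row_parallel_keys` are hypothetical inputs encoding the model-specific split):

import torch

def merge_tp_shards(shards, column_parallel_keys, row_parallel_keys):
    """Merge per-rank tensor-parallel state dicts into one state dict."""
    merged = {}
    for key in shards[0]:
        tensors = [sd[key] for sd in shards]
        if key in column_parallel_keys:
            # The output dimension was split across ranks -> concat along dim 0.
            merged[key] = torch.cat(tensors, dim=0)
        elif key in row_parallel_keys:
            # The input dimension was split across ranks -> concat along dim 1.
            merged[key] = torch.cat(tensors, dim=1)
        else:
            # Replicated parameters (e.g. layer norms) should be identical on
            # every rank; keep a single copy.
            merged[key] = tensors[0]
    return merged

Getting the per-key split dimension right is exactly the part that tends to be non-intuitive.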
The final field in each of these (except 350M, where I trained it in a different one-off pipeline using MP1) |
Cool thanks! |
That's very interesting! Thanks for all the pointers here @StellaAthena @stephenroller @younesbelkada - let's try out whether this fixes it :-) |
Okay, I think we've fixed the issue - it was caused by a poor conversion on our side where we missed those lines. Fix is in #17785; preliminary tests on 125m and 1B3 showed that the fix significantly reduces the ppl.
>>> model_path="fixed_opt_125m"
>>> prompt="Hello my name is"
>>> log_probs_with_ppl(model_path, prompt)
Input torch.Size([1, 5])
Logits torch.Size([1, 5, 50272])
torch.return_types.max(
values=tensor([[0.2398, 0.2326, 0.3332, 0.9363, 0.0097]], grad_fn=<MaxBackward0>),
indices=tensor([[ 100, 6, 766, 16, 1236]]))
argmax probability: [[0.23982257 0.23258895 0.33315504 0.9362957 0.00967377]]
argmax log probability: [[-1.4278558 -1.4584825 -1.0991473 -0.06582398 -4.6383367 ]]
argmax tokens: I, name is j
cross entropy loss: 4.051314830780029
ppl: 57.47297286987305 |
Hi @stephenroller! We would like to have a final check on something while merging #17785: do all OPT models have |
Yes, all models have that True. Yay, glad you were able to fix it! |
Hi, thanks everyone for working on this issue! I'm not sure the issue is fully resolved. I found that the weight of the final layer norm is one-initialized, which is the default initialization:
>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-13b")
>>> all(model.decoder.final_layer_norm.weight == 1)
True
>>> all(model.decoder.final_layer_norm.bias == 0)
False
Same observation with the variants 2.7B and 6.7B. This seems unexpected. Is it possible the final layer norm weight was lost at some point during the conversion? |
We observed that models 13B and larger ended up learning layer norm weights of 1; we chalked it up to (1 + epsilon) precision issues. Can you check other layer norms for the same set of models? |
I checked with 2.7B, which is faster to load, and all layer norm weights are indeed 1:
>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-2.7b")
>>> all(all(layer.final_layer_norm.weight == 1) for layer in model.decoder.layers)
True
>>> all(all(layer.self_attn_layer_norm.weight == 1) for layer in model.decoder.layers)
True
So this seems to be the expected value based on your comment. Thanks @stephenroller! |
@patrickvonplaten IMO we should double check that, this seems like a highly unlikely thing for weights ... |
This doesn't happen for 125m:
>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-125m")
>>> all(all(layer.final_layer_norm.weight == 1) for layer in model.decoder.layers)
False
>>> all(all(layer.self_attn_layer_norm.weight == 1) for layer in model.decoder.layers)
False
>>> all(model.decoder.final_layer_norm.weight == 1)
False
So I'd expect this not to be an issue with the conversion scripts or whatever... Were you able to generate using 13b? Are the generations not as good as you expected? |
I'm working on a project where we apply 8-bit quantization to the OPT linear weights. It works well for 125m, where we get the same output with and without quantization, but there are unexpected repetitions with the quantized 350m and larger variants (OpenNMT/CTranslate2#818). I found it odd that the models we have issues with are precisely the models where all layer norm weights are 1 (so far I verified this to be true for 350m, 1.3b, 2.7b, 6.7b, and 13b). As you said, this is generally unlikely for weights, so I thought it was worth mentioning. |
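For context, the kind of scheme being applied is roughly the following (a minimal sketch of symmetric per-output-channel int8 weight quantization; the details are an assumption, not the actual CTranslate2 implementation):

import torch

def quantize_int8(weight):
    # One scale per output channel (row), symmetric around zero.
    scale = weight.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.to(torch.float32) * scale

# Round-trip example on a random linear weight.
w = torch.randn(512, 512)
q, s = quantize_int8(w)
print((dequantize_int8(q, s) - w).abs().max())  # small, bounded by scale / 2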
350m didn't have issues normally. I agree it's very unlikely, we're going to run a perplexity test to check that it is fairly close to gpt2. @stephenroller did you make it untrainable and just set it at 1? |
We didn't make it untrainable. We actually spent some time digging into this and had a quite vigorous internal debate on whether this was a "problem." Again, since we trained with memory-efficient fp16 and often low LRs, my belief was that the gradients were experiencing (1 + epsilon) = 1 underflow-type issues; epsilon is much higher than you might expect:
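(For reference, fp16 has a machine epsilon of about 1e-3, so an update much smaller than that simply disappears when added to a weight of 1. A quick demonstration - the actual OPT gradient magnitudes are not shown here, this is just fp16 arithmetic:)

>>> import torch
>>> torch.finfo(torch.float16).eps
0.0009765625
>>> one = torch.tensor(1.0, dtype=torch.float16)
>>> (one + 4e-4) == one   # an update below half an ulp of 1.0 is lost entirely
tensor(True)
>>> one + 2e-3            # a larger update survives
tensor(1.0020, dtype=torch.float16)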
However, to my knowledge, this weight=1 issue didn't crop up until 13B scale... Let me see if I have some human readable checkpoints. |
To add to this, I just evaluated the perplexity on the corrected OPT checkpoints (the ones that are online on the Hub now) with the following script (slightly adapted from the original one):
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import torch.nn.functional as F


def log_probs_with_ppl(path, prompt):
    model = AutoModelForCausalLM.from_pretrained(path)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    logits = outputs.logits
    arg_probs, _ = F.softmax(logits, dim=-1).max(-1)
    print("argmax probability:", arg_probs[0].cpu().detach().numpy())
    log_probs, tokens = F.log_softmax(logits, dim=-1).max(-1)
    print("argmax log probability:", log_probs[0].cpu().detach().numpy())
    sent = tokenizer.decode(tokens.squeeze().cpu().detach().numpy(), skip_special_tokens=False)
    print("argmax tokens:", sent)
    xentropy_loss = outputs[0]
    print("cross entropy loss:", xentropy_loss.item())
    ppl = torch.exp(xentropy_loss).item()
    print("ppl:", ppl)


if __name__ == "__main__":
    prompts = "There is a book on the desk."
    for model_id in ["opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b", "opt-13b", "opt-30b"]:
        print(20 * "=" + model_id + 20 * "=")
        model_path = os.path.join("facebook", model_id)
        log_probs_with_ppl(model_path, prompts)

and the results look as follows (a bit weird that the 30B ppl is not lower) |
30B looks kinda weird with its higher ppl than the rest... |
Yeah - tried out another prompt
30B still a bit higher. But note that all those are also run on CPU (currently don't have access to a big GPU) |
(Sorry - there was one final problem with opt-30b: there was a typo with the begin token for 30B, see https://huggingface.co/facebook/opt-30b/commit/a7fad6ce41655b751249f8801cc3a24ede359d31). The updated results make more sense IMO. |
Awesome @patrickvonplaten ! So we assume the conversion was correct in that sense I guess. |
OPT-66B is up as well. Ran the above prompts on GPU only, with fp16, so slightly less precision. |
That surprises me less. The 66B was actually harder to train than the 175B, weirdly. |
System Info
Who can help?
@patrickvonplaten @LysandreJik
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
The scripts above are used to test both the GPT-2 and OPT models; the results are shown below:
Sentence Scoring-GPT2
Sentence Scoring-OPT
As the model size increases, GPT-2 tends to predict more accurately, with smaller ppl. However, the OPT models (except opt-350m) produce much larger ppl than opt-350m.
Besides, it is abnormal that as the model size increases, the OPT models seem to have higher confidence scores for the argmax decoding tokens (check the argmax probability values above).
I wonder what is causing such an issue. Looking forward to your reply. Thx!