Abnormal behavior of OPT except OPT-350m #17653

Closed
hemingkx opened this issue Jun 10, 2022 · 32 comments · Fixed by #17785
@hemingkx

hemingkx commented Jun 10, 2022

System Info

- `transformers` version: 4.20.0.dev0
- Platform: Linux-5.11.0-1020-azure-x86_64-with-debian-bullseye-sid
- Python version: 3.7.13
- Huggingface_hub version: 0.7.0
- PyTorch version (GPU?): 1.10.1+cu111 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@patrickvonplaten @LysandreJik

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

def log_probs_with_ppl(path, prompt):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
    import torch.nn.functional as F
    # for half precision (13b models): torch_dtype=torch.float16
    model = AutoModelForCausalLM.from_pretrained(path).cuda()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    outputs = model(input_ids, labels=input_ids)
    logits = outputs.logits
    arg_probs, _ = F.softmax(logits, dim=-1).max(-1)
    print("argmax probility:", arg_probs[0].cpu().detach().numpy())
    log_probs, tokens = F.log_softmax(logits, dim=-1).max(-1)
    print("argmax log probability:", log_probs[0].cpu().detach().numpy())
    sent = tokenizer.decode(tokens.squeeze().cpu().detach().numpy(), skip_special_tokens=False)
    print("argmax tokens:", sent)
    xentropy_loss = outputs[0]
    print("cross entropy loss:", xentropy_loss.item())
    ppl = torch.exp(xentropy_loss).item()
    print("ppl:", ppl)


if __name__ == "__main__":
    model_path = 'huggingface/opt-1.3b'
    prompts = "There is a book on the desk."
    log_probs_with_ppl(model_path, prompts)

Expected behavior

The script above was used to test both the gpt2 and opt models; the results are shown below:

Sentence Scoring-GPT2

Input: There is a book on the desk.

# gpt2
argmax probility: [0.05637151 0.27859244 0.08230144 0.11579145 0.1521898  0.34015855 0.1640605  0.16998181]
argmax log probability: [-2.8757913 -1.2780054 -2.4973667 -2.1559646 -1.8826268 -1.0783434 -1.8075199 -1.772064 ]
argmax tokens:  is no lot called the subject of It
cross entropy loss: 4.143580913543701
ppl: 63.02811813354492


# gpt2-Medium
argmax probility: [0.39332193 0.26675868 0.08419792 0.1576896  0.2581378  0.1720277 0.1351828  0.13347614]
argmax log probability: [-0.93312687 -1.3214109  -2.474585   -1.8471267  -1.3542618  -1.7600996 -2.0011272  -2.0138326 ]
argmax tokens:  are no lot called the subject in
cross entropy loss: 3.641242504119873
ppl: 38.13919448852539

# gpt2-Large
argmax probility: [0.33576348 0.27546927 0.09161323 0.16216931 0.29808053 0.09624117 0.16370784 0.15139417]
argmax log probability: [-1.0913483 -1.2892792 -2.3901796 -1.8191143 -1.2103915 -2.340898 -1.809672  -1.8878684]
argmax tokens:  are no lot called the subject, It
cross entropy loss: 3.1841206550598145
ppl: 24.146047592163086

Sentence Scoring-OPT

Input: There is a book on the desk.

# opt-125m
argmax probility: [0.00063085 0.00046801 0.00079859 0.00062031 0.00056935 0.00048211 0.00078747 0.00045703 0.00154377]
argmax log probability: [-7.368442  -7.6670203 -7.1326685 -7.3852873 -7.471012  -7.6373434 -7.146683  -7.690766 -6.4735274]
argmax tokens: I aren nothing difference called youtube subjectawaru It
cross entropy loss: 9.004321098327637
ppl: 8138.173828125

# opt-350m
argmax probility: [0.09890626 0.29053134 0.30989376 0.05687688 0.31782693 0.24106818 0.15811151 0.12125468 0.2616786 ]
argmax log probability: [-2.313583  -1.2360439 -1.1715258 -2.8668664 -1.1462483 -1.4226755 -1.8444548 -2.109862 -1.3406382]
argmax tokens: 's a lot called the subject in
cross entropy loss: 3.714618682861328
ppl: 41.042930603027344

# opt-1.3b
argmax probility: [0.18575612 0.51934767 0.6326897  0.51414996 0.97731984 0.9402624 0.63661176 0.20046458 0.6865138 ]
argmax log probability: [-1.6833206  -0.65518177 -0.4577752  -0.6652403  -0.0229413  -0.06159634 -0.45159534 -1.6071177  -0.37612897]
argmax tokens: I are no difference called Amazon subject beside It
cross entropy loss: 7.282720565795898
ppl: 1454.94091796875

# opt-6.7b
argmax probility: [0.9414 0.391  0.766  0.627  0.998  0.7646 0.978  0.4473 0.8735]
argmax log probability: [-0.0602   -0.939    -0.2664   -0.4668   -0.001997 -0.268    -0.02211 -0.8047   -0.135   ]
argmax tokens: I's no lot called this subject. 
cross entropy loss: 7.17578125
ppl: 1307.0

As the model size increases, gpt2 predicts more accurately and its ppl decreases. However, the opt models (except opt-350m) produce much larger ppl than opt-350m.

Besides, it is abnormal that, as the model size increases, the opt models assign higher confidence to their argmax decoding tokens (see the argmax probabilities above).

I wonder what is causing this issue. Looking forward to your reply. Thx!

@hemingkx hemingkx added the bug label Jun 10, 2022
@hemingkx
Author

Related issue #17545

@patrickvonplaten
Contributor

Hmm, that's a very good test - it indeed seems like there is a bug with the model weights conversion... @stephenroller can you by any chance run those models with metaseq to verify that ppl is reasonable?

@patrickvonplaten
Contributor

Also @hemingkx, I can confirm your results - just sent a mail to the authors since we think the problem lies in the model conversion

@hemingkx
Author

Thanks for your early reply! Hope this can be resolved soon 😊.

@stephenroller

This is in my queue, just have a long queue.

@StellaAthena
Contributor

@stephenroller Big Science was hoping to compare the BLOOM model to OPT. Do you have any idea if it’s reasonable to expect that this is resolved by the end of the month?

I would be very appreciative if you could keep me apprised of the progress.

@younesbelkada
Contributor

younesbelkada commented Jun 16, 2022

@stephenroller did you use Tensor Parallelism for the OPT models >= 1.3b? (I can't find this information in the paper.)
I ran into strange behaviour when trying to convert a BLOOM model trained with Megatron-LM. If so, this might be related, but I'm not sure.
Related issue: pytorch/pytorch#76232

@stephenroller

We used tensor parallelism for all models EXCEPT 350 :P

@younesbelkada
Contributor

younesbelkada commented Jun 16, 2022

@patrickvonplaten then that might explain these numbers on all models except 350m?
I remember running a quick test on BLOOM-176b with and without this hack: it made quite a difference quantitatively (logit exactness), but when I qualitatively compared the generation results it didn't make any difference.
@stephenroller out of curiosity, what TP size was used for these models?

@younesbelkada
Contributor

younesbelkada commented Jun 16, 2022

Also, do you remember which layers were TP-sharded?

@StellaAthena
Contributor

EleutherAI has found that merging tensor-parallel models can be extremely non-intuitive, and it took substantial work from @zphang to figure out how to do it while maintaining performance. You can find code for merging the TP shards of our 20B model here, and I think we have some more general code as well, but I would have to hunt it down to release it.
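For illustration only, here is a minimal sketch of what such a merge can look like under Megatron-style sharding; the parameter-name patterns and concat axes below are assumptions for the sake of the example, not metaseq's or EleutherAI's actual merge code:

# Illustrative sketch only. Assumes Megatron-style sharding: column-parallel
# weights are split along the output dim (dim 0), row-parallel weights along
# the input dim (dim 1), and replicated params (layer norms, etc.) are
# identical across shards. The name patterns below are hypothetical.
import torch

def merge_tp_shards(shard_state_dicts):
    merged = {}
    for key in shard_state_dicts[0]:
        tensors = [sd[key] for sd in shard_state_dicts]
        if any(p in key for p in ("fc1.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight")):
            merged[key] = torch.cat(tensors, dim=0)   # column-parallel: concat along output dim
        elif any(p in key for p in ("fc2.weight", "out_proj.weight")):
            merged[key] = torch.cat(tensors, dim=1)   # row-parallel: concat along input dim
        else:
            merged[key] = tensors[0]                  # replicated: take any shard
    return merged

Getting the split axes (or the handling of biases and fused QKV layouts) wrong produces a model that still runs but scores much worse, which would be consistent with the kind of degradation seen above.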

@stephenroller

@patrickvonplaten then it might explain these numbers on all models except 350m ? I remember trying a quick test on BLOOM-176b with this hack and without the hack and it made quite a difference quantitatively (logits exactness) but when I qualitatively compared the generation results it didn't made any difference. @stephenroller what was the TP rank used for these models just out of curiosity?

https://github.com/facebookresearch/metaseq/blob/cf24413b2c78ad2f293fb9ac53a74be20f087863/metaseq/launcher/opt_job_constants.py#L32-L44

The final field in each of these (except 350M, which I trained in a different one-off pipeline using MP1).

@younesbelkada
Contributor

Cool thanks!

@patrickvonplaten
Contributor

That's very interesting! Thanks for all the pointers here @StellaAthena @stephenroller @younesbelkada - let's try out whether this fixes it :-)

@thomasw21
Contributor

Okay, I think we've fixed the issue. It was caused by an incorrect conversion on our side, where we missed these lines from metaseq: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L466-L477

The fix is in #17785; preliminary tests on 125m and 1.3b show that it significantly reduces the ppl.

>>> model_path="fixed_opt_125m"
>>> prompt="Hello my name is"
>>> log_probs_with_ppl(model_path, prompt)
Input torch.Size([1, 5])
Logits torch.Size([1, 5, 50272])
torch.return_types.max(
values=tensor([[0.2398, 0.2326, 0.3332, 0.9363, 0.0097]], grad_fn=<MaxBackward0>),
indices=tensor([[ 100,    6,  766,   16, 1236]]))
argmax probility: [[0.23982257 0.23258895 0.33315504 0.9362957  0.00967377]]
argmax log probability: [[-1.4278558  -1.4584825  -1.0991473  -0.06582398 -4.6383367 ]]
argmax tokens: I, name is j
cross entropy loss: 4.051314830780029
ppl: 57.47297286987305

@younesbelkada
Contributor

Hi @stephenroller! We would like a final check on something before merging #17785: do all OPT models have share_input_output_embed set to True? We know it is set to True for opt-350m, but we are not sure about the other models: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L486
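(For context, on the already-converted HF checkpoints the tying can be sanity-checked like this; a minimal sketch using the generic embedding accessors, under the assumption that share_input_output_embed corresponds to weight tying between the input embeddings and the lm_head:)

from transformers import AutoModelForCausalLM
import torch

# Hypothetical sanity check on a converted checkpoint: if the input and output
# embeddings are tied, they should be the same tensor (or at least equal).
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
in_emb = model.get_input_embeddings().weight    # decoder token embeddings
out_emb = model.get_output_embeddings().weight  # lm_head weights
print("same object:", in_emb is out_emb)
print("same values:", torch.equal(in_emb, out_emb))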

@stephenroller

Yes, all models have that set to True.

Yay, glad you were able to fix it!

@guillaumekln
Contributor

guillaumekln commented Jun 22, 2022

Hi,

Thanks everyone for working on this issue!

I'm not sure the issue is fully resolved. I found that the weight of the final layer norm is one-initialized, which is the default initialization in nn.LayerNorm. The bias seems fine, however, meaning it is not zero-initialized.

>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-13b")
>>> all(model.decoder.final_layer_norm.weight == 1)
True
>>> all(model.decoder.final_layer_norm.bias == 0)
False

Same observation with the variants 2.7B and 6.7B. This seems unexpected. Is it possible the final layer norm weight was lost at some point during the conversion?

@stephenroller

We observed that models 13B and larger ended up learning layer norm weights of 1; we chalked it up to (1 + epsilon) precision issues. Can you check other layer norms for the same set of models?

@guillaumekln
Contributor

I checked with 2.7B, which is faster to load, and all layer norm weights are indeed 1:

>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-2.7b")
>>> all(all(layer.final_layer_norm.weight == 1) for layer in model.decoder.layers)
True
>>> all(all(layer.self_attn_layer_norm.weight == 1) for layer in model.decoder.layers)
True

So this seems to be the expected value based on your comment. Thanks @stephenroller!

@thomasw21
Contributor

@patrickvonplaten IMO we should double-check that; this seems like a highly unlikely thing for weights ...

@thomasw21
Contributor

This doesn't happen for 125m:

>>> from transformers import OPTModel
>>> model = OPTModel.from_pretrained("facebook/opt-125m")
>>> all(all(layer.final_layer_norm.weight == 1) for layer in model.decoder.layers)
False
>>> all(all(layer.self_attn_layer_norm.weight == 1) for layer in model.decoder.layers)
False
>>> all(model.decoder.final_layer_norm.weight == 1)
False

So I wouldn't expect this to be an issue with the conversion scripts or whatever ... Were you able to generate with 13b? Are the generations not as good as you expected?

@guillaumekln
Contributor

I'm working on a project where we apply 8-bit quantization to the OPT linear weights. It works well for 125m, where we get the same output with and without quantization, but there are unexpected repetitions with the quantized 350m and larger variants (OpenNMT/CTranslate2#818).

I found it odd that the models we have issues with are precisely the models where all layer norm weights are 1 (so far I verified this for 350m, 1.3b, 2.7b, 6.7b, and 13b). As you said, this is generally unlikely for weights, so I thought it was worth mentioning.
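(For context, the kind of 8-bit quantization I mean is roughly the following; a minimal per-row symmetric int8 fake-quantization sketch, not CTranslate2's actual implementation:)

import torch

def fake_quant_int8(weight):
    # Per-output-row symmetric int8 quantization followed by dequantization,
    # i.e. the error an int8 linear layer introduces relative to fp32 weights.
    scale = weight.abs().amax(dim=1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)                      # avoid division by zero
    q = torch.round(weight / scale).clamp(-127, 127)   # int8 value range
    return q * scale

w = torch.randn(1024, 1024)
print("max abs error:", (w - fake_quant_int8(w)).abs().max().item())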

@thomasw21
Contributor

350m normally didn't have issues. I agree it's very unlikely; we're going to run a perplexity test to check that it is fairly close to gpt2. @stephenroller did you make these weights untrainable and just set them to 1?

@stephenroller

We didn't make them untrainable. We actually spent some time digging into this and had a quite vigorous internal debate on whether this was a "problem." Again, since we trained with memory-efficient fp16, often with low LRs, my belief was that the gradient updates were hitting (1 + epsilon) = 1 underflow-type issues; epsilon is much higher than you might expect:

>>> torch.finfo(torch.float16)
finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, tiny=6.10352e-05, dtype=float16)
>>> torch.finfo(torch.bfloat16)
finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, tiny=1.17549e-38, dtype=bfloat16)

However, to my knowledge, this weight=1 issue didn't crop up until the 13B scale... Let me see if I have some human-readable checkpoints.
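A minimal illustration of that (1 + epsilon) = 1 effect in fp16 (with round-to-nearest, any update smaller than roughly eps/2 is absorbed into a weight of 1):

import torch

one = torch.tensor(1.0, dtype=torch.float16)
tiny_update = torch.tensor(3e-4, dtype=torch.float16)   # below fp16 eps/2 (~4.9e-4)
print(one + tiny_update == one)                         # tensor(True): the update is lost
print(one + torch.tensor(1e-3, dtype=torch.float16))    # ~1.001: large enough to register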

@patrickvonplaten
Contributor

To add to this, I just evaluated the perplexity of the corrected OPT checkpoints (the ones that are online on the Hub now) with the following script (slightly adapted from the original one):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import torch.nn.functional as F


def log_probs_with_ppl(path, prompt):
    model = AutoModelForCausalLM.from_pretrained(path)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)

    logits = outputs.logits
    arg_probs, _ = F.softmax(logits, dim=-1).max(-1)
    print("argmax probility:", arg_probs[0].cpu().detach().numpy())
    log_probs, tokens = F.log_softmax(logits, dim=-1).max(-1)
    print("argmax log probability:", log_probs[0].cpu().detach().numpy())
    sent = tokenizer.decode(tokens.squeeze().cpu().detach().numpy(), skip_special_tokens=False)
    print("argmax tokens:", sent)
    xentropy_loss = outputs[0]
    print("cross entropy loss:", xentropy_loss.item())
    ppl = torch.exp(xentropy_loss).item()
    print("ppl:", ppl)


if __name__ == "__main__":
    prompts = "There is a book on the desk."

    for model_id in ["opt-125m", "opt-350m", "opt-1.3b", "opt-2.7b", "opt-6.7b", "opt-13b", "opt-30b"]:
        print(20 * "=" + model_id + 20 * "=")
        model_path = os.path.join("facebook", model_id)
        log_probs_with_ppl(model_path, prompts)

and the results look as follows (it is a bit weird that the 30B ppl is not lower):

====================opt-125m====================
argmax probility: [0.23981844 0.31434286 0.32221574 0.05977888 0.34688717 0.14583494
 0.21301866 0.16133878 0.17206082]
argmax log probability: [-1.4278731 -1.157271  -1.132534  -2.817103  -1.0587558 -1.9252799
 -1.5463755 -1.8242489 -1.7599072]
argmax tokens: I's a lot called the subject that

cross entropy loss: 3.867582321166992
ppl: 47.82661819458008
====================opt-350m====================
argmax probility: [0.09871461 0.2903144  0.3098579  0.05694651 0.31773096 0.24076065
 0.15807256 0.12118834 0.26144707]
argmax log probability: [-2.3155224 -1.2367908 -1.1716415 -2.8656428 -1.1465503 -1.423952
 -1.844701  -2.1104095 -1.3415234]
argmax tokens:
's a lot called the subject in

cross entropy loss: 3.7145543098449707
ppl: 41.04029083251953
====================opt-1.3b====================
argmax probility: [0.10844591 0.2721116  0.30265862 0.03994901 0.42767116 0.21947733
 0.21639532 0.16303496 0.27209064]
argmax log probability: [-2.2215037 -1.301543  -1.1951498 -3.2201512 -0.8494007 -1.5165063
 -1.5306484 -1.8137906 -1.3016201]
argmax tokens: I's a lot called the subject.

cross entropy loss: 3.5381228923797607
ppl: 34.40228271484375
====================opt-2.7b====================
argmax probility: [0.10350165 0.29980636 0.3279065  0.04002884 0.3831317  0.17393681
 0.06864104 0.12390347 0.3078983 ]
argmax log probability: [-2.2681677 -1.2046186 -1.1150268 -3.2181551 -0.9593765 -1.7490631
 -2.6788647 -2.0882525 -1.1779857]
argmax tokens: I's a lot called the subject in

cross entropy loss: 3.407679557800293
ppl: 30.195096969604492
====================opt-6.7b====================
argmax probility: [0.10619629 0.29815957 0.3240549  0.04175518 0.38586977 0.20265782
 0.14770415 0.18028003 0.15865195]
argmax log probability: [-2.2424662  -1.2101265  -1.1268424  -3.1759317  -0.95225537 -1.5962363
 -1.912544   -1.713244   -1.8410425 ]
argmax tokens: I's a lot called the subject.

cross entropy loss: 3.3324668407440186
ppl: 28.007347106933594
====================opt-13b====================
argmax probility: [0.11410075 0.24206492 0.32771447 0.04382524 0.42932945 0.19955206
 0.09738682 0.18719622 0.23314796]
argmax log probability: [-2.1706734 -1.4185493 -1.1156125 -3.1275454 -0.8455307 -1.6116802
 -2.3290644 -1.6755979 -1.456082 ]
argmax tokens: I is a lot called the shelf.

cross entropy loss: 3.196335792541504
ppl: 24.44280242919922
====================opt-30b====================
argmax probility: [0.05695914 0.21338376 0.29346094 0.05644348 0.14937086 0.306213
 0.14233655 0.19602104 0.15788034]
argmax log probability: [-2.865421  -1.5446631 -1.2260108 -2.8745155 -1.9013231 -1.1834744
 -1.949561  -1.6295333 -1.8459179]
argmax tokens: _ is a bug called the table in

cross entropy loss: 3.548802375793457
ppl: 34.77164840698242

@stephenroller

30B looks kinda weird with its higher ppl than the rest...

@patrickvonplaten
Contributor

patrickvonplaten commented Jun 22, 2022

Yeah - tried out another prompt, prompts = "In its most general sense, the term 'world' refers to the totality of entities, to the whole of reality or to everything that is.":

====================opt-125m====================
argmax probility: [0.23982129 0.15632838 0.13140368 0.75963086 0.27387255 0.878215
 0.3593415  0.02516824 0.29971686 0.01008662 0.16188805 0.38820833
 0.94960755 0.3430672  0.10371881 0.9371656  0.35841522 0.19422102
 0.10693377 0.26196504 0.09036761 0.27197453 0.3381063  0.48046958
 0.20347297 0.47123003 0.20022035 0.25318944 0.08799087 0.22637637]
argmax log probability: [-1.4278613  -1.8557965  -2.0294812  -0.2749227  -1.2950925  -0.12986381
 -1.0234821  -3.6821725  -1.2049171  -4.5965457  -1.8208503  -0.9462132
 -0.05170649 -1.0698289  -2.2660718  -0.06489524 -1.0260631  -1.6387584
 -2.2355456  -1.3395443  -2.4038694  -1.3020469  -1.0843949  -0.73299134
 -1.5922221  -0.7524089  -1.6083368  -1.3736173  -2.4305222  -1.4855564 ]
argmax tokens: I the first recent form, the new "re- is to the world of the, not the extent of the. the the that is not

cross entropy loss: 3.4385933876037598
ppl: 31.14312171936035
====================opt-350m====================
argmax probility: [0.09871422 0.16178179 0.08027261 0.35097533 0.519601   0.88294435
 0.1850073  0.03688832 0.36755642 0.01141822 0.14522597 0.33600938
 0.9539717  0.33490822 0.17595953 0.95351464 0.3859613  0.17998883
 0.12469297 0.28728586 0.07168846 0.28822082 0.35224074 0.44462964
 0.22922419 0.33133236 0.25537238 0.7224368  0.147327   0.15977049]
argmax log probability: [-2.3155262  -1.8215069  -2.522327   -1.0470394  -0.65469414 -0.12449309
 -1.6873599  -3.2998602  -1.0008785  -4.4725447  -1.9294643  -1.0906162
 -0.04712127 -1.0938988  -1.7375013  -0.0476005  -0.95201814 -1.7148604
 -2.0819008  -1.2472775  -2.6354256  -1.2440283  -1.0434405  -0.8105136
 -1.4730548  -1.1046333  -1.3650324  -0.32512537 -1.9151007  -1.8340169 ]
argmax tokens:
 the first recent sense, the term "s' is to the world of the, including the totality of the. to the in exists.

cross entropy loss: 2.962890148162842
ppl: 19.35382652282715
====================opt-1.3b====================
argmax probility: [0.10844579 0.15118013 0.06978295 0.6313941  0.4198188  0.8240803
 0.17271882 0.16285989 0.27022356 0.01286044 0.24925965 0.28614777
 0.9725371  0.33411652 0.12802885 0.9795981  0.24178138 0.15681867
 0.05040615 0.4303817  0.07268634 0.3975281  0.31267446 0.45626673
 0.3057922  0.5389601  0.4383189  0.4491574  0.09853235 0.14064014]
argmax log probability: [-2.221505   -1.8892832  -2.6623657  -0.45982504 -0.86793214 -0.19348735
 -1.7560903  -1.814865   -1.3085057  -4.3535995  -1.3892602  -1.2512469
 -0.02784706 -1.0962654  -2.0554996  -0.0206129  -1.4197214  -1.8526651
 -2.987642   -0.84308285 -2.6216018  -0.92248964 -1.1625926  -0.7846777
 -1.1848495  -0.61811376 -0.8248085  -0.8003819  -2.3173704  -1.9615508 ]
argmax tokens: I the latest recent sense, the term "social' is to the universe of all that or which totality of the. to the that exists.

cross entropy loss: 2.8560848236083984
ppl: 17.393295288085938
====================opt-2.7b====================
argmax probility: [0.10350163 0.15577659 0.08376333 0.5673112  0.50874865 0.8624449
 0.21431658 0.15363932 0.4019405  0.01176295 0.31341943 0.22355694
 0.9744133  0.4234775  0.14364612 0.9677136  0.25195396 0.16395077
 0.05868356 0.2550737  0.1175658  0.33331287 0.24672066 0.46241838
 0.15379134 0.5719755  0.31928268 0.480221   0.26081583 0.18780598]
argmax log probability: [-2.268168   -1.8593324  -2.47976    -0.5668472  -0.6758013  -0.14798404
 -1.5403011  -1.8731475  -0.9114513  -4.442801   -1.1602129  -1.4980892
 -0.02591975 -0.8592549  -1.9404025  -0.03281909 -1.3785089  -1.808189
 -2.8355956  -1.3662028  -2.140757   -1.0986737  -1.3994985  -0.7712852
 -1.8721585  -0.5586591  -1.1416783  -0.7335089  -1.3439407  -1.6723459 ]
argmax tokens: I the first recent sense, a term �social' is to the universe of all that including which universe of the. existence the that exists. In
cross entropy loss: 2.7184252738952637
ppl: 15.1564359664917
====================opt-6.7b====================
argmax probility: [0.1061962  0.15520018 0.08349195 0.41061476 0.4894317  0.85579187
 0.19358145 0.16583215 0.3086383  0.0087592  0.25194162 0.22671384
 0.9799279  0.39524195 0.22661391 0.958905   0.28948593 0.23172891
 0.08076624 0.20174423 0.24156563 0.29045665 0.19864736 0.5106398
 0.2697726  0.5609671  0.46792567 0.72609276 0.34715182 0.18653922]
argmax log probability: [-2.242467   -1.8630395  -2.483005   -0.8900999  -0.7145103  -0.15572807
 -1.642057   -1.7967792  -1.1755853  -4.737651   -1.3785579  -1.4840667
 -0.0202763  -0.92825717 -1.4845076  -0.04196331 -1.2396486  -1.462187
 -2.5161963  -1.6007546  -1.4206141  -1.236301   -1.616224   -0.6720909
 -1.3101759  -0.578093   -0.75944585 -0.32007748 -1.057993   -1.6791137 ]
argmax tokens: I the first recent sense, the term �f' refers to the totality of all that including the totality of the. existence the that exists. In
cross entropy loss: 2.5715460777282715
ppl: 13.086040496826172
====================opt-13b====================
argmax probility: [0.11410033 0.18445235 0.07394046 0.3377592  0.4820613  0.89032376
 0.26655433 0.22948444 0.3482319  0.00974461 0.36039528 0.21514857
 0.9776878  0.32965162 0.12019254 0.9643266  0.2866262  0.18254285
 0.09594422 0.30989775 0.1353004  0.32887647 0.18270463 0.42996788
 0.2804617  0.59530467 0.37093756 0.7375983  0.372651   0.17102638]
argmax log probability: [-2.1706772  -1.6903641  -2.604495   -1.085422   -0.729684   -0.11617013
 -1.3221772  -1.4719201  -1.0548866  -4.631041   -1.0205538  -1.5364265
 -0.02256491 -1.1097189  -2.1186602  -0.03632521 -1.2495763  -1.7007704
 -2.3439884  -1.1715128  -2.000258   -1.1120731  -1.6998845  -0.84404474
 -1.2713182  -0.518682   -0.9917216  -0.30435592 -0.98711294 -1.7659374 ]
argmax tokens: I the current basic sense, a term �b' refers to the entire of all that both which totality of the. existence the that exists. In
cross entropy loss: 2.5906801223754883
ppl: 13.33884048461914
====================opt-30b====================
argmax probility: [0.05695887 0.12192544 0.118644   0.28989854 0.61544293 0.8780987
 0.21866173 0.10945266 0.18478885 0.01414752 0.61560327 0.3003071
 0.9654389  0.39148605 0.2840576  0.96917915 0.23066148 0.18086651
 0.13592616 0.27291402 0.18672818 0.301525   0.23815584 0.5115922
 0.41115332 0.5572077  0.40895483 0.7513009  0.50103414 0.13948028]
argmax log probability: [-2.865426   -2.1043456  -2.1316278  -1.2382243  -0.48541304 -0.12999624
 -1.5202293  -2.212263   -1.6885414  -4.258216   -0.48515254 -1.2029496
 -0.03517244 -0.9378054  -1.2585783  -0.0313058  -1.4668041  -1.7099961
 -1.9956435  -1.2985984  -1.6781013  -1.1989024  -1.4348301  -0.67022747
 -0.8887891  -0.58481723 -0.8941506  -0.28594905 -0.691081   -1.9698321 ]
argmax tokens: _format constructor basic form, the term "in' refers to the totality of all that events which totality of the. existence the that exists. In
cross entropy loss: 2.6525726318359375
ppl: 14.190498352050781

30B is still a bit higher. But note that all of these were run on CPU (I currently don't have access to a big GPU).

@patrickvonplaten
Contributor

(Sorry - there was one final problem with opt-30b: a typo in the begin token for 30B, see: https://huggingface.co/facebook/opt-30b/commit/a7fad6ce41655b751249f8801cc3a24ede359d31.)

Updated results make more sense IMO:

For "There is a book on the desk."

argmax probility: [0.10173301 0.2815346  0.34502357 0.04142236 0.4594141  0.24795197
 0.13434087 0.19007264 0.17950189]
argmax log probability: [-2.2854035  -1.2674999  -1.0641425  -3.1839345  -0.77780336 -1.3945202
 -2.0073748  -1.660349   -1.7175696 ]
argmax tokens: I's a lot called this subject in

cross entropy loss: 3.267496109008789
ppl: 26.245540618896484

For "In its most general sense, the term 'world' refers to the totality of entities, to the whole of reality or to everything that is."

====================opt-30b====================
argmax probility: [0.10173287 0.19473663 0.07857301 0.41664603 0.40225816 0.9184978
 0.24668908 0.32442567 0.34429762 0.01023421 0.26163322 0.21400526
 0.96574026 0.44348246 0.1901864  0.9594519  0.1832267  0.18785793
 0.11864074 0.35639173 0.21181443 0.32442087 0.2604604  0.51606554
 0.4302902  0.56918555 0.3640891  0.7693961  0.49747008 0.15007249]
argmax log probability: [-2.2854047  -1.6361073  -2.543727   -0.87551826 -0.9106612  -0.0850158
 -1.3996265  -1.1256989  -1.0662489  -4.5820193  -1.3408117  -1.5417547
 -0.03486039 -0.81309706 -1.6597506  -0.04139308 -1.6970311  -1.6720693
 -2.1316555  -1.0317248  -1.5520447  -1.1257137  -1.3453044  -0.66152155
 -0.84329545 -0.56354874 -1.0103567  -0.26214936 -0.6982199  -1.8966368 ]
argmax tokens: I the current basic sense, a present �b' refers to the totality of all that events which totality of the. existence the that exists. In
cross entropy loss: 2.567533493041992
ppl: 13.033637046813965

@thomasw21
Contributor

Awesome @patrickvonplaten! So I guess we can assume the conversion was correct in that sense.

@patrickvonplaten
Contributor

OPT-66B is up as well. I only ran the above prompts on GPU with fp16, so the precision is slightly lower.

For "There is a book on the desk.":

====================opt-66b====================
argmax probility: [0.118   0.287   0.3286  0.04562 0.39    0.2258  0.0776  0.1586  0.177  ]
argmax log probability: [-2.137 -1.248 -1.113 -3.088 -0.942 -1.488 -2.557 -1.842 -1.731]
argmax tokens: I's a lot called this subject in

cross entropy loss: 3.240234375
ppl: 25.546875

For "In its most general sense, the term 'world' refers to the totality of entities, to the whole of reality or to everything that is."

====================opt-66b====================
argmax probility: [0.11804 0.1587  0.09045 0.379   0.3762  0.9214  0.2576  0.421   0.3442
 0.012   0.2072  0.2379  0.982   0.4968  0.155   0.9736  0.1715  0.1682
 0.0908  0.3071  0.299   0.538   0.4558  0.5166  0.8755  0.526   0.736
 0.919   0.408   0.2372 ]
argmax log probability: [-2.137   -1.841   -2.402   -0.9707  -0.9775  -0.0817  -1.356   -0.865
 -1.066   -4.42    -1.574   -1.436   -0.01802 -0.699   -1.864   -0.02663
 -1.763   -1.782   -2.398   -1.181   -1.207   -0.62    -0.7856  -0.6606
 -0.1332  -0.6426  -0.307   -0.0847  -0.8965  -1.438  ]
argmax tokens: I the first recent sense, a term �f' is to the totality of all, events which totality of reality. existence the that exists. In
cross entropy loss: 2.630859375
ppl: 13.8828125

@stephenroller

That surprises me less. The 66B was actually harder to train than the 175B, weirdly.
