This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[NeuralChat] Use unk token instead of eos token #1198

Merged
merged 1 commit into from
Jan 29, 2024

Conversation

dillonalaird
Contributor

Type of Change

Bug fix

Description

In the train.py code the pad_token is set to the eos_token here. Because eos_tokens are present in every conversation, when we calculate the total_len of a conversation here we effectively ignore one token per round, since int(target.ne(tokenizer.pad_token_id).sum()) == int(target.ne(tokenizer.eos_token_id).sum()). The missing counts then have to be added back with + len([rou for rou in rounds if rou != ""]). This in itself isn't bad, but if we use another model, say teknium/OpenHermes-2.5-Mistral-7B, which actually uses a different eos_token here, the total_len calculation ends up being wrong:

ipdb> p total_len
126
ipdb> p tokenizer.pad_token_id
32000
ipdb> p target
tensor([    1,   330, 10706,  1444,   264, 13903,  2930,   304,   396, 18278,
        10895, 13892, 28723,   415, 13892,  5212, 10865, 28725, 10537, 28725,
          304, 27057, 11194,   298,   272,  2930, 28742, 28713,  4224, 28723,
         1247, 28747, 28705,  -200, 28705,    13,  3195,   460,   272,  9304,
          302,   272,  1579,   297,   272,  3469, 28804, 21631, 28747,   415,
         1579,   297,   272,  3469,   349,  3075,   304,  2760, 28723,     2,
         1247, 28747,  1824,  4480,   541,   347,  2598,   356,   272,   852,
          302,   272,  1579, 28804, 21631, 28747,   415,   852,   302,   272,
         1579,  4190,   396, 21662,  1116, 28723,     2,  1247, 28747,  1691,
          272,  1579,  7810,  1060,   272,  5948,   442,  5822,   805,   298,
          272,  2081, 28804, 21631, 28747,   415,  1579,   349,  7810,  1060,
          272,  5948, 28725,   690,   349, 22558,   395,   905,   304,   799,
        11999, 28723,     2])
ipdb> p int(target.ne(0).sum())
123

where 123 is the correct total_len and 0 is typically the unk token id. This mismatch causes the warning to be thrown here. A fix is to use the unk token as the pad token, as the original LLaVA code does here, and then revert the total_len calculation to the original implementation here.
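The miscount can be reproduced with a toy target tensor (a minimal sketch; the ids below are made up for illustration, with 0 standing in for the unk/pad token and 2 for the eos token, mirroring the ids in the dump above):

```python
import torch

# Hypothetical token ids: 0 plays the role of unk, 2 the role of eos,
# matching the Mistral-style ids shown in the debugger dump above.
unk_id, eos_id = 0, 2

# A toy target with three rounds, each terminated by an eos token.
target = torch.tensor([1, 5, 6, 2, 7, 8, 2, 9, 2])

# If pad_token == eos_token, ne(pad_token_id) also masks out the eos
# tokens, undercounting by one token per round:
len_eos_pad = int(target.ne(eos_id).sum())  # 6: three eos tokens dropped

# If pad_token == unk_token, every real token is counted, including eos:
len_unk_pad = int(target.ne(unk_id).sum())  # 9: the correct total_len
```

With the unk token used as the pad token, the LLaVA-style total_len computation counts eos tokens correctly and no longer needs the `+ len([rou for rou in rounds if rou != ""])` correction.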

Expected Behavior & Potential Risk

With this change, total_len is computed correctly for models whose eos_token differs from the default, and the tokenization mismatch warning is no longer thrown.

How has this PR been tested?

how to reproduce the test (including hardware information)

Dependency Change?

any library dependency introduced or removed

Signed-off-by: Dillon Laird <dillonalaird@gmail.com>
@hshen14 hshen14 merged commit 6387a0a into intel:main Jan 29, 2024
9 checks passed