-
Notifications
You must be signed in to change notification settings - Fork 9.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : improve BPE pre-processing + LLaMA 3 and Deepseek support #6920
Conversation
This comment was marked as resolved.
This comment was marked as resolved.
unicode.cpp
Outdated
static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) | ||
{ | ||
// code to convert from utf32/utf16 to utf8 | ||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>, wchar_t> converter; | ||
std::string utf8 = converter.to_bytes(ws); | ||
return utf8; | ||
static inline std::string unicode_wstring_to_utf8(const std::wstring & ws) { | ||
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv; | ||
return conv.to_bytes(ws); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dragnil1 Not sure if this is the intent, but the following change of this function makes the tokenizer tests pass on my Mac. Do you think this is OK to change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change converts UCS-2 or UCS-4/UTF-32 encoded std::wstring
to UTF-8 encoded std::string
and the previous one, converts UTF-16 encoded std::wstring
to UTF-8 encoded std::string
according to reference. Both works on Ubuntu(tested) but I am not sure about windows as it uses UTF-16 encoded std::wstring
.
llama.cpp
Outdated
std::vector<std::string> word_collection; | ||
switch (vocab.type) { | ||
case LLAMA_VOCAB_TYPE_BPE: | ||
switch (vocab.arch) { | ||
// TODO: how to detect deepseek and llama v3 models? | ||
//case LLM_ARCH_LLAMA: | ||
//case LLM_ARCH_DEEPSEEK_CODER: | ||
// word_collection = unicode_regex_split(text, { | ||
// "[\r\n]", | ||
// "\\s?\\p{L}+", | ||
// "\\s?\\p{P}+", | ||
// "[一-龥ࠀ-一가-]+", | ||
// "\\p{N}+" | ||
// }); | ||
// break; | ||
//case LLM_ARCH_DEEPSEEK_LLM: | ||
// word_collection = unicode_regex_split(text, { | ||
// "[\r\n]", | ||
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", | ||
// "\\s?[!-/:-~!-/:-~‘-‟ -。]+", | ||
// "\\s+$", | ||
// "[一-龥ࠀ-一가-]+", | ||
// "\\p{N}+" | ||
// }); | ||
// break; | ||
default: | ||
// default regex for BPE tokenization pre-processing | ||
{ | ||
word_collection = unicode_regex_split(text, { | ||
"\\p{P}+", | ||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", | ||
"\\p{N}+", | ||
"[0-9][0-9][0-9]" | ||
}); | ||
} | ||
break; | ||
} | ||
break; | ||
default: | ||
GGML_ASSERT(false); | ||
break; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the missing part - how to distinguish models from one another?
For example all LLaMA
, Deepseek Coder
and Deepseek LLM
models have the same architecture:
"architectures": [
"LlamaForCausalLM"
],
There seems to be no way to automatically determine which model we are converting. Therefore, there is no way to automatically determine the correct regex to use.
Seems we will have to rely on some heuristics based on the rest of the parameters, such as vocab size and tensor sizes. Not great
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry if I'm new to the inner workings of llama.cpp and get something wrong, but is vocab.arch
coming from the gguf_metadata_kv_t in the gguf?
If it's not coming from there, would it be reasonable to add it as a key in the gguf? Then the file could specify what it needs, and llama.cpp could use that, or otherwise just fallback to the current behavior.
The gguf specification talks about how it "is designed to be unambiguous by containing all the information needed to load a model", and this seems like information needed to load a model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that when creating the GGUF file in the first place (i.e. during conversion from HF to GGUF) there is no way to know what model we are dealing with. For example, take these 2 models:
- LLaMA v3 8B Instruct: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/tree/main
- DeepSeek LLM 7B Chat: https://huggingface.co/deepseek-ai/deepseek-llm-7b-chat/tree/main
Both use LLaMA architecture, both use BPE tokenizer and so currently they will be interpreted as the same arch by llama.cpp
.
However, they use different pre-tokenizers:
LLaMA:
"normalizer": null,
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
},
DeepSeek LLM:
"normalizer": {
"type": "Sequence",
"normalizers": []
},
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "[\r\n]"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "Split",
"pattern": {
"Regex": "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "Split",
"pattern": {
"Regex": "\\s?[!-/:-~!-/:-~‘-‟ -。]+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "Split",
"pattern": {
"Regex": "\\s+$"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "Split",
"pattern": {
"Regex": "[一-龥ࠀ-一가-]+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "Digits",
"individual_digits": true
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
},
So maybe we have to start parsing this information from the tokenizer.json
and use it to determine the correct arch. Not sure yet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking more about this, I'm starting to consider the option where we tokenize a few strings during conversion and based on the resulting tokens we add a new enum to the GGUF header indicating the pre-tokenizer type. In llama.cpp
we will have custom implementations of each pre-tokenizer type with a fallback to some default pre-tokenizer (as we already do)
In the convert script, if the strings tokenize to unknown set of tokens, we stop with an error asking the developer to check the pre-tokenizer configuration and either assign an existing one or add a new one to the enum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
option where we tokenize a few strings during conversion
It looks like a pretty messy solution. Maybe it's better to choose a variant with parsing tokenizer.json and make alternative implementation on C++?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is a prototype of the idea above:
llama.cpp/convert-hf-to-gguf.py
Lines 379 to 407 in 9b4d63a
def get_vocab_base_pre(self, tokenizer) -> str: | |
# encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that | |
# is specific for the BPE pre-tokenizer used by the model | |
# we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can | |
# use in llama.cpp to implement the same pre-tokenizer | |
chktxt = "\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български what's ''''''```````\"\"\"\"......!!!!!!??????" | |
chktok = tokenizer.encode(chktxt) | |
chkhsh = hash(tuple(chktok)) | |
print(f"chktok: {chktok}") | |
print(f"chkhsh: {chkhsh}") | |
res = None | |
# NOTE: if you get an error here, you need to add the model to the if-elif chain below | |
# observe the stdout for the chkhsh value and add it to the chain | |
if self.model_arch == gguf.MODEL_ARCH.LLAMA: | |
if chkhsh == -3290901550109860290: | |
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer.json | |
res = "llama3" | |
if chkhsh == 4190561703949727616: | |
# ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct/blob/main/tokenizer.json | |
res = "deepseek-coder" | |
if res is None: | |
raise NotImplementedError(f"BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") | |
Feedback is welcome
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I went through each of these steps awhile back and realized there's no way to do it unless I create my own way of doing it. Automating it was just not an option. For detection though, why not just create a hash sum of the encodings as a list? e.g. hash(tuple([k, v for k, v in tokenizer.model.vocab.items()]))
.
Could probably do this for any model file. Only issue is knowing the sum value in advance, which means they'd need to be added manually. This would need to include added tokens and any other required misc files.
This comment was marked as resolved.
This comment was marked as resolved.
Edit: llama3 outdated. Redownload json files.
Edit: my llama3 model was outdated. Fetching latest version made this go away. |
Is phi-2 included in the fixes? It uses |
@BramVanroy I'm about to have my breakfast. I'll add it to my PR #7018 if it isn't merged by then. This is an issue with every model supported by the HF script. They all require a hash and pretokenizer in order to be validated. The quality of the output is degraded otherwise. I had to regen all the converted models I use. Spent last night uploading the ones I care about. |
@teleprint-me I had some users tell me that sometimes generation degrade significantly after a while using ollama. I can't reproduce it on plain Python so I came looking for a potential issue with llamacpp. I've generated the hash in #7022 so you can just copy that, I think. |
@BramVanroy That's-possibly just-an Ollama issue. Model generation on latest llama.cpp is phenomenal. My micro pretrained models putput quality skyrocketed with this PR update for some reason. I think it depends on the model because I tested phi 1, 2, and 3, llama 3, and mistral 7v2 as well as stablelm 1.6. |
Both llama.cpp (b2776) and koboldcpp (1.64) seem to be fine now, but ollama as of 0.1.32 still has tokenizer issues (ollama/ollama#4082). |
Hi, I just want to clarify some things (I'm currently on c4ec9c0). It appears that there are currently two ways to successfully convert HF llama3 models to gguf:
The original pytorch checkpoints from Meta have to be converted to HF as mentioned here: #6819 (I used the script from The tokenizer config in the Meta repo is slightly modified over the raw conversion:
This now creates a total of 4 possible ways to generate a gguf:
The I'm not really up-to-date on this stuff, but I assume (very naively) that the tokenizer changes just shifted around the logic to a different part, so they should be equivalent in the end? Or could this have any meaningful effect on the gguf results? |
I'm not sure the best place to comment this, but the current llama-3 Q_4_M performance compared with Groq seems quite different. Maybe this is due to the quantitation. Can anyone confirm? It's very clear to me as CrewAI won't run with Ollama without additional prompt tweaks, vs just running on Groq. |
@nkeilar Did you try using regular llama cpp to test the prompt ? Since I found that output using server and regular llama cpp is quite different (ollama is like llama.cpp server output) Edit : I compare to groq too and regular llama cpp = groq, server is not |
* merged the changes from deepseeker models to main branch * Moved regex patterns to unicode.cpp and updated unicode.h * Moved header files * Resolved issues * added and refactored unicode_regex_split and related functions * Updated/merged the deepseek coder pr * Refactored code * Adding unicode regex mappings * Adding unicode regex function * Added needed functionality, testing remains * Fixed issues * Fixed issue with gpt2 regex custom preprocessor * unicode : fix? unicode_wstring_to_utf8 * lint : fix whitespaces * tests : add tokenizer tests for numbers * unicode : remove redundant headers * tests : remove and rename tokenizer test scripts * tests : add sample usage * gguf-py : reader prints warnings on duplicate keys * llama : towards llama3 tokenization support (wip) * unicode : shot in the dark to fix tests on Windows * unicode : first try custom implementations * convert : add "tokenizer.ggml.pre" GGUF KV (wip) * llama : use new pre-tokenizer type * convert : fix pre-tokenizer type writing * lint : fix * make : add test-tokenizer-0-llama-v3 * wip * models : add llama v3 vocab file * llama : adapt punctuation regex + add llama 3 regex * minor * unicode : set bomb * unicode : set bomb * unicode : always use std::wregex * unicode : support \p{N}, \p{L} and \p{P} natively * unicode : try fix windows * unicode : category support via std::regex * unicode : clean-up * unicode : simplify * convert : add convert-hf-to-gguf-update.py ggml-ci * lint : update * convert : add falcon ggml-ci * unicode : normalize signatures * lint : fix * lint : fix * convert : remove unused functions * convert : add comments * convert : exercise contractions ggml-ci * lint : fix * cmake : refactor test targets * tests : refactor vocab tests ggml-ci * tests : add more vocabs and tests ggml-ci * unicode : cleanup * scripts : ignore new update script in check-requirements.sh * models : add phi-3, mpt, gpt-2, starcoder * tests : disable obsolete ggml-ci * tests : use faster bpe test ggml-ci * llama : more prominent warning for old BPE models * tests : disable test-tokenizer-1-bpe due to slowness ggml-ci --------- Co-authored-by: Jaggzh <jaggz.h@gmail.com> Co-authored-by: Kazim Abrar Mahi <kazimabrarmahi135@gmail.com>
Continuing the work in #6252 by @dragnil1
This PR adds support for BPE pre-tokenization to
llama.cpp
Summary
The state so far has been that for all BPE-based models,
llama.cpp
applied a default pre-tokenization inherited back from GPT-2:llama.cpp/llama.cpp
Lines 12186 to 12196 in e00b4a8
This works most of the times since BPE models use similar pre-tokenization strategies. However, there are cases where this fails: #6914. This leads to poor generation quality because the model starts to work with out-of-distribution data when the pre-tokenization splits the input string in the wrong way
There are 2 main obstacles in introducing proper BPE pre-tokenization:
<regex>
does not support sophisticated regexes which are typically used for pre-tokenization: llama : improve BPE pre-processing + LLaMA 3 and Deepseek support #6920 (comment)transformers
define it intokenizer.json
https://huggingface.co/docs/transformers/en/main_classes/configurationBoth introducing a dedicated regex library or supporting complex json configurations are out-of-scope for
llama.cpp
. Therefore, this PR implements the following solution:std::regex
andstd::wregex
+ some pre-processing: llama : improve BPE pre-processing + LLaMA 3 and Deepseek support #6920 (comment)Details
Introduce new convert-hf-to-gguf-update.py script
llama.cpp/convert-hf-to-gguf-update.py
Lines 1 to 19 in 120cf37
From now on, we start listing all supported models in it:
llama.cpp/convert-hf-to-gguf-update.py
Lines 47 to 56 in c21ab18
During conversion with
convert-hf-to-gguf.py
, if the hash of the tokens of a large string are not recognized, we prompt for update ofconvert-hf-to-gguf-update.py
llama.cpp/convert-hf-to-gguf.py
Lines 263 to 315 in c21ab18
For now, this is required only for BPE models, since it seems SPM does not use pre-tokenization
The string used for the hashing should be extended to cover as much pre-tokenizer functionality as possible:
llama.cpp/convert-hf-to-gguf-update.py
Lines 37 to 40 in c21ab18
Pre-tokenizer types are identified via a string written to the GGUF header:
llama.cpp/llama.cpp
Line 397 in c21ab18
For each pre-tokenizer, we have to tell
llama.cpp
what pre-processing regexes to use:llama.cpp/llama.cpp
Lines 12087 to 12141 in c21ab18
Here, we have to inspect manually the contents of the
tokenizer.json
of the model and either reuse an existing set of regex patterns, or add a new one corresponding to the new configuration. For a tutorial, see 120cf37. We verify the correctness using thetests/test-tokenizer-0
program and the exported vocab for that model:Old GGUF models using BPE tokenizers, generated before this change, will fallback to the "default" pre-tokenization, which in almost all cases is wrong. A warning is printed in the output:
llama.cpp/llama.cpp
Lines 4333 to 4352 in 80cb312
Although we now support pre-processing using regexes, there is now also infrastructure to add more custom splitting implementations in order to have better performance:
llama.cpp/unicode.cpp
Lines 424 to 432 in c21ab18
For example, there is already an attempt to add custom LLaMA v3 pre-tokenization: llama3 custom regex split #6965
The tokenizer tests have been refactored to allow easy addition of more tests and vocabs. Add tests here and run
convert-hf-to-gguf-update.py
to create input/output files for all known tokenizer models:llama.cpp/convert-hf-to-gguf-update.py
Lines 181 to 225 in c21ab18
llama.cpp/tests/CMakeLists.txt
Lines 68 to 79 in c21ab18
TODOs
Fix custom GPT-2 pre-processing bug:
llama.cpp/unicode.cpp
Lines 430 to 434 in 120cf37
Fix MPT pre-tokenization:
llama.cpp/llama.cpp
Lines 12136 to 12146 in 120cf37