-
Notifications
You must be signed in to change notification settings - Fork 10.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : add reranking support #9510
Conversation
!! I'm a big fan of the BGE embedding models (they're incredibly user-friendly to fine-tune on ones' own datasets) -- I'm really happy to see support being added for this! I'll definitely take a look and review. |
Initial working version using the In the meantime, will add a |
@ggerganov |
Hopefully sometime this week |
Dear @ggerganov, I highly appreciate the reranking support. I would like to suggest adding support for the gte-multilingual-reranker-base model, which employs the 'NewForSequenceClassification' architecture. This addition would significantly enhance multilingual processing. Thank you! |
Benchmarking different rerankers lower index = better
https://github.com/user-attachments/files/17119411/results_1727199777.4010012.xlsx |
@ExtReMLapin How is this benchmark performed? |
You're right, I forgot to specify how we use it, so what benchmark we used. At the office, as we only need rag we need embeddings and rerankers only to perform "needle in the haystack" challenge.
The only thing we care about is answering "can this question be answered with this text chunk", from our Point of view, rerankers are just upgraded embeddings models (it's technically different, I'm just talking about the final purpose). Please note the questions-answer pairs are VERY similars, this is why we expects the rerankers to have a final score close to ZERO. We considers others rerankers to be unusable. Again, this is a benchmark we built in two hours just to fit our needs. If anyone want to take a look at the code, here it is : evaluate_rerankers.py.txt |
@ExtReMLapin Thanks for the benchmark. I wonder if you could add bce reranker (https://huggingface.co/maidalun1020/bce-reranker-base_v1). It's an interesting one because it claims to have "meaningful rerank score". See their repo for more information: https://github.com/netease-youdao/BCEmbedding Edit: oops, I didn't realize it doesn't support the other three languages except English, in your table :( |
ggml-ci
5b6468f
to
62a45d1
Compare
If a reranker is not even succeding at the needing in the haystack challenge better than bge-m3 embeddings, it's literally not worth our disk space. For reference, int he EXACT same benchmark, but we use embeddings models instead, everage score on BGE-M3 is ~4.5 So a reranker here is doing pretty much WORSE than embeddings 🤡 I'm re-running another test and I added 3 fanfictions pdf and one question-pair set where the link is much more subtile between question and answer. |
I finished running more tests using the new dataset For reference, in the first needle in the haystack challenge, the searched text (needle/answer) was very similar to the question (query). Questions-pairs were like : {
"question": "When was Peter Donkey Born ?",
"needles": [
"Peter Donkey was born in november in 1996",
"P. Donkey was born in 1996",
"Peter Donkey est né en novembre 1996",
"Peter Donkey ese nacio en 1996",
],
},
{
"question": "What is the height of Mount Everest?",
"needles": [
"Mount Everest measures 8,848 meters above sea level.",
"The tallest mountain is 8,848 meters high.",
"La montagne la plus haute mesure 8 848 mètres, c'est l'Everest.",
"La montaña más alta mide 8,848 metros.",
"Der höchste Berg der Welt ist 8.848 Meter hoch.",
],
},
{
"question": "Who invented the telephone?",
"needles": [
"Alexander Graham Bell is credited with the invention of the telephone.",
"The telephone was first patented by Bell in 1876.",
"Le téléphone a été inventé par Alexander Graham Bell.",
"El teléfono fue inventado por Alexander Graham Bell.",
"Das Telefon wurde von Alexander Graham Bell erfunden.",
],
}, You didn't even need a reranker to find the answer as it had similar words I added a new dataset called subtle which is more like this : {
"question": "When did Peter eat a fruit ?",
"needles": [ #link is fruit -> apple
"Right after he went to the gym, he ate an apple.",
],
},
{
"question": "What did the criminal do to get in jail ?",
"needles": [ # link is jail -> emprisoned
"He's emprisoned because he stole a car.",
],
},
{
"question": "What did the doctor prescribe to the patient ?",
"needles": [ #link is doctor/patient -> hospital
"Back from the hospital, he got penicilin.",
],
},
{
"question": "What did the teacher give to the student ?",
"needles": [ #link is teacher/student -> school
"At school, he received a book.",
],
},
{
"question": "What is used to quench thirst?",
"needles": [ #link is thirst -> drink
"After the long walk, he drank a glass of water.",
],
},
|
No, I think it's expected to have some small VRAM usage. You probably want to deploy 2 For the CUDA instance, make sure to offload all layers and enable flash attention: |
You can also set the |
I tried, but I can't find a trivial way to separate the reranker part of ChatLLM so I don't really know how to get the result to compare. Will try again when I have more time. |
@ggerganov sadly that, after disabling |
@rujialiu, there are easy-to-use Python bindings: |
Oops. I forgot to check python bindings. Thanks! And a (hopefully) easier way to debug, is to trace through the official code of bge: from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
score = reranker.compute_score([['A red apple', 'A llama in the garden'], ['A red apple', 'I want some fruit']])
print(score) # [-8.7265625, -0.10467529296875] |
I looked deeper into this. After printing tokenization result, I noticed a difference. llama.cpp:
However, FlagEmbedding's output (I modified its source code to print tokenizer result):
So the BOS of the second sequence is 2 instead of 0. I don't know why, but if I change llama.cpp's code to enforce it to be 0 like this: for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];
inp[5] = 2; // <--- just for troubleshooting Then it can get very similar result: -8.797 (vs tei's -8.773)! I've also tested some other equal-length pairs like this: score = reranker.compute_score([['A red apple', 'A llama in the garden'], ['A red apple', 'I want some fruit!'], ['A red apple', 'Another nice looking big apple']])
print(score) # [-8.7265625, -1.53125, 2.203125] llama.cpp's output:
That's very close! EDIT: I made a mistake in my original post. After changing second setence's BOS to 2, actually everything is correct now. So we probably just need:
|
Nice find! So it really depends on what is the correct formatting of the input. The approach that we have implemented in Adding sigmoid option is easy. |
I'm curious too. So I searched in class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
...
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs,
): So they're adding |
Also, for impl Default for RobertaProcessing {
fn default() -> Self {
Self {
sep: ("</s>".into(), 2),
cls: ("<s>".into(), 0),
trim_offsets: true,
add_prefix_space: true,
}
}
} see https://docs.rs/tokenizers/latest/src/tokenizers/processors/roberta.rs.html#16-25 |
Ok, it's clear now - we need to use:
|
After reading more codes in def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLM-RoBERTa sequence has the following format:
- single sequence: `<s> X </s>`
- pair of sequences: `<s> A </s></s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep So it's 3 |
It's my fault: the pattern in #8555 (comment)) is wrong. Fixed now. After the tokenizer got fixed, I think the results of chatllm and infinity/tei matched now. |
Gently ping @ggerganov for the sigmoid option |
Seems like a good first issue for new contributors. Feel free to create one. |
Ok, I'll try next week |
I ended up adding sigmoid in |
We can leave it up to the client to apply sigmoid if they need to. It's a trivial operation and don't think there is any need to do it server-side. But we can add this option nevertheless. We just have to keep the existing behavior as default and optionally enable the sigmoid per-request. |
You're right. I changed that because I want it to be a painless (almost drop-in) replacement of |
@QuintinShaw we ran more tests (real life tests) at the office. We took the second book of harry potter, and used a sigmoid to give us a 0-1 score for the reranker. Alibaba reranker is terrible compared to BGE reranker. On the left, you can see sorted scores of each chunks (80) compared to the query "Which animal lives in the Chamber of Secrets ?" On the right, same but for BGE M3 V2 Alibaba doesn't highlights any chunk, while BGE reranker does |
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
* py : add XLMRobertaForSequenceClassification [no ci] * py : fix scalar-tensor conversion [no ci] * py : fix position embeddings chop [no ci] * llama : read new cls tensors [no ci] * llama : add classigication head (wip) [no ci] * llama : add "rank" pooling type ggml-ci * server : add rerank endpoint ggml-ci * llama : aboud ggml_repeat during classification * rerank : cleanup + comments * server : accept /rerank endpoint in addition to /v1/rerank [no ci] * embedding : parse special tokens * jina : support v1 reranker * vocab : minor style ggml-ci * server : initiate tests for later ggml-ci * server : add docs * llama : add comment [no ci] * llama : fix uninitialized tensors * ci : add rerank tests ggml-ci * add reranking test * change test data * Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * add `--reranking` argument * update server docs * llama : fix comment [no ci] ggml-ci --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
ref #8555
This adds initial support for reranking to
libllama
,llama-embeddings
andllama-server
. I've tested mainly with the following 2 models:The reranking is implemented as a pooling layer of type
LLAMA_POOLING_TYPE_RANK
. When used,libllama
will attach a classification head at the end of the graph:llama.cpp/src/llama.cpp
Lines 10246 to 10266 in 4d45775
The current implementation likely does not cover all types of rerankers, so updates would be necessary in the future to support other types of classifications on a case-by-case basis.
The computed rank scores for each sequence can be accessed via the
llama_get_embeddings_seq()
call:llama.cpp/include/llama.h
Lines 873 to 878 in 4d45775
The rank score is stored as a single float.
The server endpoint is designed mostly after https://jina.ai/reranker/, but it is not fully complete. Again, I think it's better to update it on a case-by-case basis + the API is not ideal (e.g. what is the purpose of
top_n
?, why are the document contents returned in the response?).I started to add server tests, but it will take me more time to write the python code, so I'll create a separate issue for people to help with that:
llama.cpp/examples/server/tests/features/rerank.feature
Lines 19 to 25 in 4d45775
TODO:
tests(left for follow-up PR)Model: https://huggingface.co/BAAI/bge-reranker-v2-m3
Testing:
python3 convert_hf_to_gguf.py \ ~/Data/huggingface/bge-reranker-v2-m3/ \ --outfile models/bge-reranker-v2-m3/ggml-model-f16.gguf \ --outtype f16
Classifier:
https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
Testing (CLI)
Rank responses:
./llama-embedding \ -m models/bge-reranker-v2-m3/ggml-model-f16.gguf \ -p "what is panda?</s><s>hi\nwhat is panda?</s><s>it's a bear\nwhat is panda?</s><s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." \ --pooling rank --embd-normalize -1 --verbose-prompt
Testing (server)
result: