llama : add reranking support #9510

Merged
merged 25 commits into master from gg/rerank on Sep 28, 2024

Commits
3453e62
py : add XLMRobertaForSequenceClassification [no ci]
ggerganov Sep 16, 2024
77723ed
py : fix scalar-tensor conversion [no ci]
ggerganov Sep 17, 2024
49f90de
py : fix position embeddings chop [no ci]
ggerganov Sep 17, 2024
dc0cdd8
llama : read new cls tensors [no ci]
ggerganov Sep 17, 2024
d0a7bf9
llama : add classification head (wip) [no ci]
ggerganov Sep 18, 2024
125a067
llama : add "rank" pooling type
ggerganov Sep 19, 2024
6235c62
server : add rerank endpoint
ggerganov Sep 19, 2024
6916ed1
llama : avoid ggml_repeat during classification
ggerganov Sep 23, 2024
62a45d1
rerank : cleanup + comments
ggerganov Sep 25, 2024
7bde9a0
server : accept /rerank endpoint in addition to /v1/rerank [no ci]
ggerganov Sep 25, 2024
c62a39d
embedding : parse special tokens
ggerganov Sep 25, 2024
866c011
jina : support v1 reranker
ggerganov Sep 25, 2024
84f56f3
vocab : minor style
ggerganov Sep 25, 2024
00b3376
server : initiate tests for later
ggerganov Sep 26, 2024
877a04c
server : add docs
ggerganov Sep 26, 2024
4d45775
llama : add comment [no ci]
ggerganov Sep 26, 2024
ca99a6c
llama : fix uninitialized tensors
ggerganov Sep 26, 2024
f19554f
ci : add rerank tests
ggerganov Sep 26, 2024
f27dd69
add reranking test
ngxson Sep 26, 2024
1ae8376
change test data
ngxson Sep 26, 2024
84b0af8
Update examples/server/server.cpp
ggerganov Sep 27, 2024
0d6f6a7
add `--reranking` argument
ngxson Sep 27, 2024
a4ac45f
update server docs
ngxson Sep 27, 2024
39167b6
llama : fix comment [no ci]
ggerganov Sep 28, 2024
aeac876
Merge branch 'master' into gg/rerank
ggerganov Sep 28, 2024
85 changes: 76 additions & 9 deletions ci/run.sh
@@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
set +e
}

function gg_sum_embd_bge_small {
gg_printf '### %s\n\n' "${ci}"

gg_printf 'BGE Small (BERT):\n'
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
}

# rerank_tiny

function gg_run_rerank_tiny {
cd ${SRC}

gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json

gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json

path_models="../models-mnt/rerank-tiny"

rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release

set -e

(time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log

python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf

model_f16="${path_models}/ggml-model-f16.gguf"

(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s><s>hi\nwhat is panda?</s><s>it's a bear\nwhat is panda?</s><s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log

# sample output
# rerank score 0: 0.029
# rerank score 1: 0.029
# rerank score 2: 0.135

# check that the score is in the range [$3, $4]
function check_score {
qnt="$1"
score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)

if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then
printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4"
return 20
fi

printf ' - %s @ %s OK\n' "$qnt" "$score"
return 0
}

check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log

set +e
}

function gg_sum_rerank_tiny {
gg_printf '### %s\n\n' "${ci}"

gg_printf 'Rerank Tiny (Jina):\n'
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)"
}

function gg_check_build_requirements {
if ! command -v cmake &> /dev/null; then
gg_printf 'cmake not found, please install'
@@ -726,15 +801,6 @@ function gg_check_build_requirements {
fi
}

function gg_sum_embd_bge_small {
gg_printf '### %s\n\n' "${ci}"

gg_printf 'BGE Small (BERT):\n'
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
}

## main

export LLAMA_LOG_PREFIX=1
@@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release

if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run embd_bge_small
test $ret -eq 0 && gg_run rerank_tiny

if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
test $ret -eq 0 && gg_run test_scripts_debug
18 changes: 15 additions & 3 deletions common/arg.cpp
@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
params.kv_overrides.back().key[0] = 0;
}

if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}

return true;
}

@@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.verbose_prompt = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
));
add_opt(llama_arg(
{"--no-display-prompt"},
format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
}
).set_sparam());
add_opt(llama_arg(
{"--pooling"}, "{none,mean,cls,last}",
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
[](gpt_params & params, const std::string & value) {
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
@@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(llama_arg(
{"--reranking", "--rerank"},
format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
[](gpt_params & params) {
params.reranking = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(llama_arg(
{"--api-key"}, "KEY",
"API key to use for authentication (default: none)",
5 changes: 5 additions & 0 deletions common/common.cpp
@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;

if (params.reranking) {
cparams.embeddings = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
}

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

1 change: 1 addition & 0 deletions common/common.h
@@ -271,6 +271,7 @@ struct gpt_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
bool reranking = false; // enable reranking support on server

// server params
int32_t port = 8080; // server listens on this network port
27 changes: 24 additions & 3 deletions convert_hf_to_gguf.py
@@ -291,8 +291,13 @@ def prepare_tensors(self):
bid = int(part)
break

for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
data: np.ndarray # type hint
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
data = data_torch.squeeze().numpy()

# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
data = data_torch.numpy()

n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)

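The empty-shape check above exists because `squeeze()` collapses single-element tensors to zero dimensions. A standalone sketch of the behavior (illustrative, assuming only `torch` is installed; not the converter's actual code):

```python
import torch

t = torch.tensor([0.5])        # 1-element tensor, e.g. a classifier bias
data = t.squeeze().numpy()     # squeeze() collapses it to a 0-dim array
assert len(data.shape) == 0    # the "empty" case the converter checks for

data = t.numpy()               # restore: keep the original 1-dim shape
assert data.shape == (1,)
```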
@@ -592,6 +597,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
# ref: https://huggingface.co/databricks/dbrx-base
res = "dbrx"
if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
# ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
res = "jina-v1-en"
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
res = "jina-v2-en"
@@ -2601,7 +2609,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])


@Model.register("XLMRobertaModel")
@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT

@@ -2699,6 +2707,11 @@ def set_vocab(self):
self.gguf_writer.add_add_eos_token(True)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "roberta.", remove the prefix
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
if name.startswith("roberta."):
name = name[8:]

# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
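The chop mentioned in the comment above can be pictured as follows (a hypothetical sketch with assumed XLM-RoBERTa-style shapes, where `max_position_embeddings` is 514 and `pad_token_id` is 1; not the converter's actual code):

```python
import torch

pad_token_id = 1
weight = torch.randn(514, 768)        # embeddings.position_embeddings.weight
position_offset = pad_token_id + 1    # positions 0..pad_token_id are reserved
weight = weight[position_offset:, :]  # chop -> 512 usable positions
```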
@@ -3110,6 +3123,14 @@ def set_vocab(self):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "bert.", remove the prefix
# e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
if name.startswith("bert."):
name = name[5:]

return super().modify_tensors(data_torch, name, bid)


@Model.register("OpenELMForCausalLM")
class OpenELMModel(Model):
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
{"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
7 changes: 6 additions & 1 deletion examples/embedding/embedding.cpp
@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
for (const auto & prompt : prompts) {
auto inp = ::llama_tokenize(ctx, prompt, true, false);
auto inp = ::llama_tokenize(ctx, prompt, true, true);
if (inp.size() > n_batch) {
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
__func__, (long long int) inp.size(), (long long int) n_batch);
@@ -234,6 +234,11 @@ int main(int argc, char ** argv) {
}
LOG("\n");
}
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
for (int j = 0; j < n_embd_count; j++) {
// NOTE: if you change this log - update the tests in ci/run.sh
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
39 changes: 38 additions & 1 deletion examples/server/README.md
@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
**Features:**
* LLM inference of F16 and quantized models on GPU and CPU
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
* Reranking endpoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
* Parallel decoding with multi-user support
* Continuous batching
* Multimodal (wip)
@@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co
| -------- | ----------- |
| `-h, --help, --usage` | print usage and exit |
| `--version` | show version and build info |
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
| `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -130,14 +132,15 @@ The project is under active development, and we are [looking for feedback and co
| `--no-context-shift` | disables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
| `-sp, --special` | special tokens output enabled (default: false) |
| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
| `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
| `--host HOST` | ip address to listen (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |


Note: If both the command line argument and environment variable are set for the same param, the argument will take precedence over the env var.

Example usage of docker compose with environment variables:
@@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does.

`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be referenced in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.

### POST `/reranking`: Rerank documents according to a given query

Similar to https://jina.ai/reranker/ but might change in the future.
Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options (the `--reranking` flag added in this PR enables both).

*Options:*

`query`: The query against which the documents will be ranked.

`documents`: An array of strings representing the documents to be ranked.

*Aliases:*
- `/rerank`
- `/v1/rerank`
- `/v1/reranking`

*Examples:*

```shell
curl http://127.0.0.1:8012/v1/rerank \
-H "Content-Type: application/json" \
-d '{
"model": "some-model",
"query": "What is panda?",
"top_n": 3,
"documents": [
"hi",
"it is a bear",
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
]
}' | jq
```
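
For reference, a minimal Python client sketch (illustrative only, assuming the server runs locally on port 8012 as above and returns a Jina-style `results` list with an `index` and `relevance_score` per document; exact fields may change while the endpoint is WIP):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8012/v1/rerank",
    json={
        "model": "some-model",
        "query": "What is panda?",
        "top_n": 3,
        "documents": ["hi", "it is a bear", "The giant panda is a bear species endemic to China."],
    },
)

# print each document index with its score, most relevant first
for item in sorted(resp.json()["results"], key=lambda r: r["relevance_score"], reverse=True):
    print(item["index"], item["relevance_score"])
```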

### POST `/infill`: For code infilling.

Takes a prefix and a suffix and returns the predicted completion as stream.