diff --git a/ci/run.sh b/ci/run.sh index 1ac08ee4e19a8..7d241ecc0ea06 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -712,6 +712,81 @@ function gg_run_embd_bge_small { set +e } +function gg_sum_embd_bge_small { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'BGE Small (BERT):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" +} + +# rerank_tiny + +function gg_run_rerank_tiny { + cd ${SRC} + + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + function gg_check_build_requirements { if ! command -v cmake &> /dev/null; then gg_printf 'cmake not found, please install' @@ -726,15 +801,6 @@ function gg_check_build_requirements { fi } -function gg_sum_embd_bge_small { - gg_printf '### %s\n\n' "${ci}" - - gg_printf 'BGE Small (BERT):\n' - gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" - gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" - gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" -} - ## main export LLAMA_LOG_PREFIX=1 @@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then test $ret -eq 0 && gg_run test_scripts_debug diff --git a/common/arg.cpp b/common/arg.cpp index 6880117ed8001..8266a16c261c5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx params.kv_overrides.back().key[0] = 0; } + if (params.reranking && params.embedding) { + throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); + } + return true; } @@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.verbose_prompt = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + )); add_opt(llama_arg( {"--no-display-prompt"}, format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), @@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last}", + {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); @@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, params.embedding = true; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(llama_arg( + {"--reranking", "--rerank"}, + format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"), + [](gpt_params & params) { + params.reranking = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(llama_arg( {"--api-key"}, "KEY", "API key to use for authentication (default: none)", diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95a737..e2b8574bf77d7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + if (params.reranking) { + cparams.embeddings = true; + cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index cb87c4479ed0a..8b84cf9ad45ee 100644 --- a/common/common.h +++ b/common/common.h @@ -271,6 +271,7 @@ struct gpt_params { int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embendings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cd5a8c11bc18..96a8830e9e7a3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -291,8 +291,13 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + data = data_torch.squeeze().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) @@ -592,6 +597,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en res = "jina-v2-en" @@ -2601,7 +2609,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2699,6 +2707,11 @@ def set_vocab(self): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: @@ -3110,6 +3123,14 @@ def set_vocab(self): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "bert.", remove the prefix + # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + if name.startswith("bert."): + name = name[5:] + + return super().modify_tensors(data_torch, name, bid) + @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4d11059f374d2..022354a3b624e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a438dcb5adf34..7349268223827 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -135,7 +135,7 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; for (const auto & prompt : prompts) { - auto inp = ::llama_tokenize(ctx, prompt, true, false); + auto inp = ::llama_tokenize(ctx, prompt, true, true); if (inp.size() > n_batch) { LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); @@ -234,6 +234,11 @@ int main(int argc, char ** argv) { } LOG("\n"); } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + // NOTE: if you change this log - update the tests in ci/run.sh + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { diff --git a/examples/server/README.md b/examples/server/README.md index dfca07f988824..951c4a44c6058 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal (wip) @@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `--verbose-prompt` | print a verbose prompt before generation (default: false) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | @@ -130,7 +132,7 @@ The project is under active development, and we are [looking for feedback and co | `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `-sp, --special` | special tokens output enabled (default: false) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | +| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | | `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | @@ -138,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | +| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication (default: none)
(env: LLAMA_API_KEY) | | `--api-key-file FNAME` | path to file containing API keys (default: none) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | @@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | + Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. Example usage of docker compose with environment variables: @@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. +### POST `/reranking`: Rerank documents according to a given query + +Similar to https://jina.ai/reranker/ but might change in the future. +Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options. + + *Options:* + + `query`: The query against which the documents will be ranked. + + `documents`: An array strings representing the documents to be ranked. + + *Aliases:* + - `/rerank` + - `/v1/rerank` + - `/v1/reranking` + + *Examples:* + + ```shell + curl http://127.0.0.1:8012/v1/rerank \ + -H "Content-Type: application/json" \ + -d '{ + "model": "some-model", + "query": "What is panda?", + "top_n": 3, + "documents": [ + "hi", + "it is a bear", + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." + ] + }' | jq + ``` + ### POST `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 61ff09bb2b40f..f343cc252f89a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,6 +92,7 @@ enum server_task_type { enum server_task_cmpl_type { SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_CMPL_TYPE_INFILL, }; @@ -172,6 +173,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -954,8 +956,17 @@ struct server_context { slot.prompt = *prompt; } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { slot.prompt = prompt->at(0); + } else if (prompt->is_array() && prompt->size() > 1) { + // array of strings + for (const auto & el : *prompt) { + if (!el.is_string()) { + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + slot.prompt = *prompt; } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1389,6 +1400,7 @@ struct server_context { res.data = json { {"embedding", std::vector(n_embd, 0.0f)}, + {"index", slot.index}, }; continue; @@ -1407,6 +1419,44 @@ struct server_context { queue_results.send(res); } + void send_rerank(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.error = false; + res.stop = true; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res.data = json { + {"index", slot.index}, + {"score", -1e6}, + }; + + continue; + } + + res.data = json { + {"index", slot.index}, + {"score", embd[0]}, + }; + } + + SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); + + queue_results.send(res); + } + // // Functions to create new task(s) and receive result(s) // @@ -1442,13 +1492,27 @@ struct server_context { // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // prompts[0] is the question + // the rest are the answers/documents + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); + for (size_t i = 1; i < prompts.size(); i++) { + json qd; + qd.push_back(prompts[0]); + qd.push_back(prompts[i]); + data["index"] = i - 1; + create_task(data, true, qd); + } + } else { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } else { + throw std::runtime_error(error_msg); + } } } } @@ -1492,7 +1556,9 @@ struct server_context { return; } - size_t idx = result.data["index"]; + const size_t idx = result.data["index"]; + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = result; } result_handler(results); @@ -1903,6 +1969,7 @@ struct server_context { // track if this is an embedding or non-embedding batch // if we've added sampled tokens above, we are in non-embedding mode // -1: none, 0: non-embedding, 1: embedding + // TODO: make enum int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; // next, batch any pending prompts without exceeding n_batch @@ -1951,6 +2018,29 @@ struct server_context { } prompt_tokens = embd_inp; + } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // require slot.prompt to be array of 2 strings + if (!slot.prompt.is_array() || slot.prompt.size() != 2) { + SLT_ERR(slot, "%s", "invalid prompt for rerank task\n"); + slot.release(); + send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST); + continue; + } + + // prompt: querydoc + prompt_tokens.clear(); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[0], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[1], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } @@ -1970,7 +2060,7 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { slot.release(); @@ -2048,7 +2138,8 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2056,7 +2147,10 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; + const bool slot_type = + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; + if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2229,6 +2323,13 @@ struct server_context { continue; // continue loop of slots } + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { @@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) { }; const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) { // TODO: maybe merge this function with "handle_completions_generic" const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) { }; const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + // TODO: somehow clean up this checks in the future + if (!ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } const json body = json::parse(req.body); bool is_openai = false; @@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) { res_ok(res, root); }; + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::vector documents = json_value(body, "documents", std::vector()); + if (documents.empty()) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // construct prompt object: array of ["query", "doc0", "doc1", ...] + json prompt; + prompt.push_back(query); + for (const auto & doc : documents) { + prompt.push_back(doc); + } + + LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank(body, responses); + res_ok(res, root); + }; + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.loras.size(); ++i) { @@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) { svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/rerank", handle_rerank); + svr->Post("/reranking", handle_rerank); + svr->Post("/v1/rerank", handle_rerank); + svr->Post("/v1/reranking", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); // LoRA adapters hotswap diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature index 818ea3beb90cd..f4fe2ee4335ff 100644 --- a/examples/server/tests/features/embeddings.feature +++ b/examples/server/tests/features/embeddings.feature @@ -15,7 +15,7 @@ Feature: llama.cpp server And 128 as batch size And 128 as ubatch size And 512 KV cache size - And embeddings extraction + And enable embeddings endpoint Then the server is starting Then the server is healthy diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature new file mode 100644 index 0000000000000..c36cc8e215fa6 --- /dev/null +++ b/examples/server/tests/features/rerank.feature @@ -0,0 +1,42 @@ +@llama.cpp +@rerank +Feature: llama.cpp server + + Background: Server startup + Given a server listening on localhost:8080 + And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf + And a model file jina-reranker-v1-tiny-en.gguf + And a model alias jina-reranker-v1-tiny-en + And 42 as server seed + And 2 slots + And 512 as batch size + And 512 as ubatch size + And 512 KV cache size + And enable reranking endpoint + Then the server is starting + Then the server is healthy + + Scenario: Rerank + Given a rerank query: + """ + Machine learning is + """ + And a rerank document: + """ + A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines. + """ + And a rerank document: + """ + Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants. + """ + And a rerank document: + """ + Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions. + """ + And a rerank document: + """ + Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine. + """ + When reranking request + Then reranking results are returned + Then reranking highest score is index 2 and lowest score is index 3 diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 0fea0fe87b799..2611614ba3633 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.server_api_key = None context.server_continuous_batching = False context.server_embeddings = False + context.server_reranking = False context.server_metrics = False context.server_process = None context.seed = None @@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.concurrent_tasks = [] context.prompts = [] + context.reranking_query = None + context.reranking_documents = [] + context.reranking_results = None + @step('a model file {hf_file} from HF repo {hf_repo}') def step_download_hf_model(context, hf_file: str, hf_repo: str): @@ -172,10 +177,13 @@ def step_server_continuous_batching(context): context.server_continuous_batching = True -@step('embeddings extraction') +@step('enable embeddings endpoint') def step_server_embeddings(context): context.server_embeddings = True +@step('enable reranking endpoint') +def step_server_reranking(context): + context.server_reranking = True @step('prometheus compatible metrics exposed') def step_server_metrics(context): @@ -452,6 +460,14 @@ def step_impl(context, n_ga_w): def step_prompt_passkey(context): context.prompt_passkey = context_text(context) +@step('a rerank query') +def step_set_rerank_query(context): + context.reranking_query = context_text(context) + context.reranking_documents = [] + +@step('a rerank document') +def step_set_rerank_document(context): + context.reranking_documents.append(context_text(context)) @step('{n_prompts:d} fixed prompts') def step_fixed_prompts(context, n_prompts): @@ -619,6 +635,22 @@ async def step_compute_embedding(context): context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url) +@step('reranking request') +@async_run_until_complete +async def step_compute_reranking(context): + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + async with session.post(f'{context.base_url}/reranking', + json={ + "query": context.reranking_query, + "documents": context.reranking_documents, + }) as response: + if response.status == 200: + response_json = await response.json() + context.reranking_results = response_json['results'] + else: + context.reranking_results = response.status + + @step('all embeddings are the same') @async_run_until_complete async def step_all_embeddings_are_the_same(context): @@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context): for i in range(n_embedding_requests): assert_embeddings(context.tasks_result.pop().pop()) +@step('reranking results are returned') +def reranking_results_are_returned(context): + assert len(context.reranking_results) == len(context.reranking_documents) + +@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}') +def reranking_results_are_returned(context, idx_high: int, idx_low: int): + max_score, max_idx = 0, 0 + min_score, min_idx = 0, 0 + for res in context.reranking_results: + if max_score < res['relevance_score']: + max_score = res['relevance_score'] + max_idx = res['index'] + if min_score > res['relevance_score']: + min_score = res['relevance_score'] + min_idx = res['index'] + print(context.reranking_results) + assert max_idx == idx_high + assert min_idx == idx_low @step('adding special tokens') def step_tokenize_set_add_special(context): @@ -1362,6 +1412,8 @@ def start_server_background(context): server_args.append('--cont-batching') if context.server_embeddings: server_args.append('--embedding') + if context.server_reranking: + server_args.append('--reranking') if context.server_metrics: server_args.append('--metrics') if context.model_alias: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index f093f547ff2c1..47dfdfde512dc 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, - {"usage", json { + {"usage", json { // TODO: fill {"prompt_tokens", 0}, {"total_tokens", 0} }}, @@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { // TODO: fill + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"results", data} + }; + + return res; +} + static bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2fd2e9d2be828..ebe66a4a39f5f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -345,6 +345,8 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -504,6 +506,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -613,6 +617,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -644,6 +650,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5ef91f11d312f..e7e9b6fd5efbc 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,15 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier", # jina + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 4ea8a2c2b664b..7cae1bbe2e5b8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -193,6 +193,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -202,9 +203,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -872,7 +873,8 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e4d844a73c216..d2f34ddd6b339 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1554,7 +1554,7 @@ std::vector llama_tokenize_internal( } break; case LLAMA_VOCAB_TYPE_UGM: { - if (add_special && vocab.tokenizer_add_bos != 0) { + if (add_special && vocab.tokenizer_add_bos) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -1572,14 +1572,14 @@ std::vector llama_tokenize_internal( } } - if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (add_special && vocab.tokenizer_add_eos == 1) { + if (add_special && vocab.tokenizer_add_eos) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } @@ -1791,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = token_text; llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); - } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } @@ -1806,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } diff --git a/src/llama.cpp b/src/llama.cpp index 512c61e7deb19..467f8a5852b72 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -606,6 +606,8 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -793,6 +795,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -828,6 +832,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, }, }, { @@ -2894,6 +2899,7 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // TODO: should init all tensors to nullptr struct ggml_tensor * tok_embd; struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; @@ -2906,6 +2912,12 @@ struct llama_model { struct ggml_tensor * output_b; struct ggml_tensor * output_norm_enc; + // classifier + struct ggml_tensor * cls; + struct ggml_tensor * cls_b; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + std::vector layers; llama_split_mode split_mode; @@ -5604,11 +5616,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base } } break; @@ -6313,6 +6325,7 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-code") { @@ -6439,7 +6452,12 @@ static void llm_load_vocab( for (uint32_t i = 0; i < n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); @@ -6520,8 +6538,14 @@ static void llm_load_vocab( vocab.linefeed_id = ids[0]; } else { const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + vocab.linefeed_id = vocab.special_pad_id; + } else { + vocab.linefeed_id = ids[0]; + } } // special tokens @@ -7398,6 +7422,12 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_BERT) { model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); @@ -7450,6 +7480,8 @@ static bool llm_load_tensors( model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -10283,6 +10315,10 @@ struct llm_build_context { struct ggml_tensor * cur; switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; case LLAMA_POOLING_TYPE_MEAN: { struct ggml_tensor * inp_mean = build_inp_mean(); @@ -10294,9 +10330,26 @@ struct llm_build_context { struct ggml_tensor * inp_cls = build_inp_cls(); cur = ggml_get_rows(ctx0, inp, inp_cls); } break; - case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_RANK: { - cur = inp; + struct ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + GGML_ASSERT(model.cls != nullptr); + GGML_ASSERT(model.cls_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); + cur = ggml_tanh(ctx0, cur); + + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (model.cls_out) { + GGML_ASSERT(model.cls_out_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + } } break; default: { @@ -11525,8 +11578,8 @@ struct llm_build_context { inpL = cur; } - // final output cur = inpL; + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -16686,7 +16739,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { const int64_t n_tokens = batch.n_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seqs = batch.n_seqs; @@ -16701,7 +16756,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const llama_seq_id seq_id = batch.seq_id[s][0]; // TODO: adapt limits to n_seqs when batch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); for (int i = 0; i < n_seq_tokens; ++i) { const llama_pos pos = batch.pos[s*n_seq_tokens + i]; @@ -17241,6 +17296,20 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - a single float per sequence + auto & embd_seq_out = lctx.embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); @@ -17447,6 +17516,13 @@ static int llama_encode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type");