diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a0ca9d98c978b..18d6512608901 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -236,7 +236,7 @@ int main(int argc, char ** argv) { } } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { for (int j = 0; j < n_embd_count; j++) { - LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]); + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); } } else { // print the first part of the embeddings or for a single prompt, the full embedding diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 71d29002d74a0..ce65164d1f2ac 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1419,7 +1419,7 @@ struct server_context { queue_results.send(res); } - void send_rank(const server_slot & slot, const llama_batch & batch) { + void send_rerank(const server_slot & slot, const llama_batch & batch) { server_task_result res; res.id = slot.id_task; res.error = false; @@ -1440,7 +1440,7 @@ struct server_context { res.data = json { {"index", slot.index}, - {"rank", -1e6}, + {"score", -1e6}, }; continue; @@ -1448,11 +1448,11 @@ struct server_context { res.data = json { {"index", slot.index}, - {"rank", embd[0]}, + {"score", embd[0]}, }; } - SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str()); + SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); queue_results.send(res); } @@ -1493,6 +1493,9 @@ struct server_context { else if (prompt.is_array()) { std::vector prompts = prompt; if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // prompts[0] is the question + // the rest are the answers/documents + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); for (size_t i = 1; i < prompts.size(); i++) { json qd; qd.push_back(prompts[0]); @@ -1501,6 +1504,7 @@ struct server_context { create_task(data, true, qd); } } else { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); for (size_t i = 0; i < prompts.size(); i++) { const auto & e = prompts[i]; if (e.is_string() || json_is_array_of_numbers(e)) { @@ -1965,6 +1969,7 @@ struct server_context { // track if this is an embedding or non-embedding batch // if we've added sampled tokens above, we are in non-embedding mode // -1: none, 0: non-embedding, 1: embedding + // TODO: make enum int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; // next, batch any pending prompts without exceeding n_batch @@ -2133,6 +2138,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } + // non-causal tasks require to fit the entire prompt in the physical batch if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { @@ -2318,7 +2324,7 @@ struct server_context { } if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { - send_rank(slot, batch_view); + send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; continue; // continue loop of slots diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 91e7f792d28d6..47dfdfde512dc 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) { for (const auto & rank : ranks) { data.push_back(json{ {"index", i++}, - {"relevance_score", json_value(rank, "rank", 0.0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, }); } diff --git a/include/llama.h b/include/llama.h index 6601b3444864a..94341d78acb54 100644 --- a/include/llama.h +++ b/include/llama.h @@ -192,7 +192,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, - LLAMA_POOLING_TYPE_RANK = 4, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -202,9 +202,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -872,7 +872,8 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // diff --git a/src/llama.cpp b/src/llama.cpp index f0f7b67cf801c..b7c0fa4f4bf23 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17009,7 +17009,7 @@ static int llama_decode_internal( } break; case LLAMA_POOLING_TYPE_RANK: { - // extract the rank score - a single float per sequence + // extract the rerank score - a single float per sequence auto & embd_seq_out = lctx.embd_seq; for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { @@ -17211,7 +17211,6 @@ static int llama_encode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: - case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; @@ -17228,6 +17227,13 @@ static int llama_encode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type");