Skip to content

Commit

Permalink
refactor: support rerank usage
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <thxcode0824@gmail.com>
  • Loading branch information
thxCode committed Oct 7, 2024
1 parent 218db66 commit ef92202
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
1 change: 1 addition & 0 deletions llama-box/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,7 @@ struct server_context {
};
}

res.data["tokens_evaluated"] = slot.n_prompt_tokens;
queue_results.send(res);
}

Expand Down
23 changes: 14 additions & 9 deletions llama-box/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,10 +741,13 @@ static int32_t jinaicompat_rerank_top_n_response_partition(json &result, int32_t

static json jinaicompat_rerank_response(const json &request, json &result) {
int32_t top_n = json_value(request, "top_n", 1);

int num_prompt_tokens = 0;
json prompt = request.at("prompt");
json data = json::array();
if (top_n == int32_t(prompt.size()) - 1) {
for (const json &ret : result) {
num_prompt_tokens += json_value(ret, "tokens_evaluated", 0);
const int32_t idx = json_value(ret, "index", 0);
const double scr = json_value(ret, "score", 0.0);
data.push_back(json{
Expand All @@ -766,19 +769,21 @@ static json jinaicompat_rerank_response(const json &request, json &result) {
index = jinaicompat_rerank_top_n_response_partition(result, start, end);
}
}
for (int32_t i = 0; i < top_n; i++) {
for (int32_t i = 0; i < int32_t(result.size()); i++) {
const json &ret = result[i];
const int32_t idx = json_value(ret, "index", 0);
const double scr = json_value(ret, "score", 0.0);
data.push_back(json{
{"index", idx},
{"document", {{"text", prompt[idx + 1]}}},
{"relevance_score", scr},
});
num_prompt_tokens += json_value(ret, "tokens_evaluated", 0);
if (i < top_n) {
const int32_t idx = json_value(ret, "index", 0);
const double scr = json_value(ret, "score", 0.0);
data.push_back(json{
{"index", idx},
{"document", {{"text", prompt[idx + 1]}}},
{"relevance_score", scr},
});
}
}
}

int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
json res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
Expand Down

0 comments on commit ef92202

Please sign in to comment.