refactor: prompt process
Signed-off-by: thxCode <thxcode0824@gmail.com>
thxCode committed Jul 12, 2024
1 parent 2c6e5cf commit dcc21f9
Showing 2 changed files with 40 additions and 36 deletions.
7 changes: 4 additions & 3 deletions llama-box/main.cpp
@@ -3199,9 +3199,10 @@ int main(int argc, char **argv) {
         }
 
         json request = json::parse(req.body);
-        if (!request.contains("messages")) {
-            res_error(res, format_error_response("\"messages\" must be provided",
-                                                 ERROR_TYPE_INVALID_REQUEST));
+        if (!request.contains("messages") || !request.at("messages").is_array()) {
+            res_error(res,
+                      format_error_response("\"messages\" must be provided and must be an array",
+                                            ERROR_TYPE_INVALID_REQUEST));
             return;
         }
         request = oaicompat_completion_request(ctx_server.model, request, params.chat_template);
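
For context, a minimal standalone sketch (not part of this commit) of the tightened validation above, assuming nlohmann::json for parsing; the request bodies are hypothetical:

// Standalone sketch: mirrors the new condition that "messages" must exist
// and be a JSON array before the request is handed to oaicompat_completion_request.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static bool has_valid_messages(const json &request) {
    return request.contains("messages") && request.at("messages").is_array();
}

int main() {
    json ok       = json::parse(R"({"messages": [{"role": "user", "content": "hi"}]})");
    json bad_type = json::parse(R"({"messages": "hi"})");   // present but not an array
    json missing  = json::parse(R"({"model": "foo"})");     // no "messages" at all

    std::cout << has_valid_messages(ok) << "\n";        // 1
    std::cout << has_valid_messages(bad_type) << "\n";  // 0 (only rejected after this change)
    std::cout << has_valid_messages(missing) << "\n";   // 0 (already rejected before)
    return 0;
}

The only behavioral change here is that a non-array "messages" value is now rejected up front instead of reaching the OpenAI-compatible request conversion.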
69 changes: 36 additions & 33 deletions llama-box/utils.hpp
@@ -103,36 +103,34 @@ static inline void server_log(const char *level, const char *function, int line,
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model *model, const std::string &tmpl,
                                const std::vector<json> &messages) {
-    size_t alloc_size = 0;
-    // vector holding all allocated string to be passed to llama_chat_apply_template
-    std::vector<std::string> str(messages.size() * 2);
-    std::vector<llama_chat_message> chat(messages.size());
-
-    for (size_t i = 0; i < messages.size(); ++i) {
-        const auto &curr_msg = messages[i];
-        str[i * 2 + 0] = json_value(curr_msg, "role", std::string(""));
-        str[i * 2 + 1] = json_value(curr_msg, "content", std::string(""));
-        alloc_size += str[i * 2 + 1].length();
-        chat[i].role = str[i * 2 + 0].c_str();
-        chat[i].content = str[i * 2 + 1].c_str();
-    }
-
-    const char *ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf(alloc_size * 2);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true,
-                                            buf.data(), buf.size());
+    std::vector<llama_chat_msg> chat;
+
+    for (const auto &curr_msg : messages) {
+        std::string role = json_value(curr_msg, "role", std::string(""));
+
+        std::string content;
+        if (curr_msg.contains("content")) {
+            if (curr_msg["content"].is_string()) {
+                content = curr_msg["content"].get<std::string>();
+            } else if (curr_msg["content"].is_array()) {
+                for (const json &part : curr_msg["content"]) {
+                    if (part.contains("text")) {
+                        content += "\n" + part["text"].get<std::string>();
+                    }
+                }
+            } else {
+                throw std::runtime_error("Invalid 'content' type (ref: "
+                                         "https://github.com/ggerganov/llama.cpp/issues/8367)");
+            }
+        } else {
+            throw std::runtime_error(
+                "Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+        }
 
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t)res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(),
-                                        buf.size());
+        chat.push_back({role, content});
     }
 
-    const std::string formatted_chat(buf.data(), res);
-    return formatted_chat;
+    return llama_chat_apply_template(model, tmpl, chat, true);
 }
 
 //
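
As an aside, a minimal standalone sketch (not the llama-box code itself, assuming nlohmann::json) of how the rewritten format_chat flattens a message's "content" before templating: a plain string is used as-is, an array contributes only its "text" parts (each prefixed with a newline), and any other type raises an error:

// Standalone sketch: distills the content-flattening logic of the new format_chat.
#include <iostream>
#include <stdexcept>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static std::string flatten_content(const json &msg) {
    if (!msg.contains("content")) {
        throw std::runtime_error("Missing 'content'");
    }
    const json &content = msg.at("content");
    if (content.is_string()) {
        return content.get<std::string>();
    }
    if (content.is_array()) {
        std::string out;
        for (const json &part : content) {
            if (part.contains("text")) {
                out += "\n" + part["text"].get<std::string>();
            }
        }
        return out;
    }
    throw std::runtime_error("Invalid 'content' type");
}

int main() {
    json plain = json::parse(R"({"role": "user", "content": "hello"})");
    json parts = json::parse(R"({
        "role": "user",
        "content": [
            {"type": "text", "text": "describe this image"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
        ]
    })");

    std::cout << flatten_content(plain) << "\n"; // "hello"
    std::cout << flatten_content(parts) << "\n"; // "\ndescribe this image" (the image part has no "text")
    return 0;
}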
@@ -430,21 +428,26 @@ static json oaicompat_completion_request(const struct llama_model *model, const
     // Apply chat template to the list of messages
     if (chat) {
         const json messages = body.at("messages");
-        bool has_array_content = false;
+        bool chat_vision = false;
         for (const json &msg : messages) {
-            if (msg.at("content").is_array()) {
-                has_array_content = true;
-                break;
+            if (!msg.contains("content") || !msg.at("content").is_array()) {
+                continue;
+            }
+            for (const json &part : msg.at("content")) {
+                if (part.contains("type") && part.at("type") == "image_url") {
+                    chat_vision = true;
+                    break;
+                }
             }
         }
-        if (!has_array_content) {
+        if (!chat_vision) {
             llama_params["prompt"] = format_chat(model, chat_template, messages);
         } else {
             llama_params["__oaicompat_completion_chat_vision"] = true;
             // Parse the vision messages,
             // see https://platform.openai.com/docs/guides/vision
             for (const json &msg : messages) {
-                if (msg.at("role") == "user") {
+                if (msg.contains("role") && msg.at("role") == "user") {
                     llama_params["prompt"] = msg.at("content");
                     break;
                 }
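
To illustrate the new routing, a minimal standalone sketch (not the llama-box code, assuming nlohmann::json; the sample payloads are made up) of the vision detection above: only a message whose array-form "content" contains a part with "type": "image_url" flips the request onto the vision path, following https://platform.openai.com/docs/guides/vision:

// Standalone sketch: mirrors the new chat_vision detection loop.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static bool detect_chat_vision(const json &messages) {
    for (const json &msg : messages) {
        if (!msg.contains("content") || !msg.at("content").is_array()) {
            continue; // plain-string content can never carry an image part
        }
        for (const json &part : msg.at("content")) {
            if (part.contains("type") && part.at("type") == "image_url") {
                return true;
            }
        }
    }
    return false;
}

int main() {
    json text_only = json::parse(R"([
        {"role": "user", "content": [{"type": "text", "text": "hi"}]}
    ])");
    json with_image = json::parse(R"([
        {"role": "user", "content": [
            {"type": "text", "text": "what is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
        ]}
    ])");

    std::cout << detect_chat_vision(text_only) << "\n";  // 0 -> prompt built via format_chat
    std::cout << detect_chat_vision(with_image) << "\n"; // 1 -> vision path, raw user content kept
    return 0;
}

Compared with the old has_array_content check, array-form content that contains only text parts now stays on the normal format_chat path instead of being treated as a vision request.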
