server: fix a race condition caused by "request_completion"
ngxson committed Jan 23, 2024
1 parent d083c81 commit 8f36df8
Showing 2 changed files with 44 additions and 24 deletions.
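
Context for the diff: request_completion() used to generate the task id itself and only return it after queue_tasks.post() had already handed the task to a worker, so queue_results.add_waiting_task_id() could only run afterwards. A fast worker could send() the result before the HTTP handler was registered as a waiter, and the result was lost. The fix reserves the id first with queue_tasks.get_new_id(), registers the waiter, and only then posts the task. Below is a minimal, self-contained sketch of that ordering; result_queue is a stand-in, not the server's real result queue, and the method names merely mirror the ones used in the patch.

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <unordered_set>

// Stand-in for the server's result queue: results are only kept for ids that
// were registered as "waiting" beforehand, mirroring add_waiting_task_id().
struct result_queue {
    std::mutex m;
    std::condition_variable cv;
    std::unordered_set<int> waiting; // ids a caller will recv()
    std::unordered_set<int> ready;   // ids whose result has arrived

    void add_waiting_task_id(int id)    { std::lock_guard<std::mutex> l(m); waiting.insert(id); }
    void remove_waiting_task_id(int id) { std::lock_guard<std::mutex> l(m); waiting.erase(id); ready.erase(id); }

    // worker side: results for unregistered ids are dropped
    void send(int id) {
        std::lock_guard<std::mutex> l(m);
        if (waiting.count(id)) { ready.insert(id); cv.notify_all(); }
    }

    // caller side: block until the result for `id` arrives
    void recv(int id) {
        std::unique_lock<std::mutex> l(m);
        cv.wait(l, [&] { return ready.count(id) > 0; });
    }
};

int main() {
    result_queue results;
    int next_id = 0;

    const int task_id = next_id++;                       // 1. reserve the id up front
    results.add_waiting_task_id(task_id);                // 2. register the waiter *before* posting
    std::thread worker([&] { results.send(task_id); });  // 3. post; the worker may finish instantly

    results.recv(task_id);                               // the result cannot be missed
    results.remove_waiting_task_id(task_id);
    worker.join();
    std::printf("got result for task %d\n", task_id);
}

With the old ordering, the worker's send() could run before add_waiting_task_id(), and the result was dropped before recv() ever subscribed.
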
60 changes: 39 additions & 21 deletions examples/server/server.cpp
@@ -1122,9 +1122,10 @@ struct llama_server_context
queue_results.send(res);
}

int request_completion(json data, bool infill, bool embedding, int multitask_id)
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{
task_server task;
task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
@@ -1135,11 +1136,11 @@ struct llama_server_context
// when a completion task's prompt array is not a singleton, we split it into multiple requests
if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
{
return split_multiprompt_task(task);
split_multiprompt_task(task_id, task);
}

// otherwise, it's a single-prompt task, we actually queue it
return queue_tasks.post(task);
queue_tasks.post(task);
}

// for multiple images processing
@@ -1218,25 +1219,30 @@ struct llama_server_context
queue_tasks.post(task);
}

int split_multiprompt_task(task_server& multiprompt_task)
void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
{
int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);

int multitask_id = queue_tasks.get_next_id();
// generate all the IDs for the subtasks
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
{
subtask_ids[i] = queue_tasks.get_new_id();
}

// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);

// add subtasks
for (int i = 0; i < prompt_count; i++)
{
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];

// subtasks inherit everything else (infill mode, embedding mode, etc.)
subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
}

// queue up the multitask so we can track its subtask progression
queue_tasks.add_multitask(multitask_id, subtask_ids);
return multitask_id;
}

void process_single_task(task_server& task)
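
The reworked split_multiprompt_task() above applies the same idea to multi-prompt requests: every subtask id is reserved first, the parent multitask is registered with add_multitask() before any subtask is posted, and only then are the subtasks dispatched through request_completion(), so no subtask can finish before the multitask that tracks it exists. A rough sketch of that three-phase ordering, with a stand-in for the real queue:

#include <cstdio>
#include <vector>

// Stand-in for queue_tasks: only the two calls used by the new ordering.
struct task_queue {
    int next_id = 0;
    int get_new_id() { return next_id++; }
    void add_multitask(int multitask_id, const std::vector<int> & subtask_ids) {
        std::printf("multitask %d now tracks %zu subtasks\n", multitask_id, subtask_ids.size());
    }
};

int main() {
    task_queue queue_tasks;
    const int multitask_id = queue_tasks.get_new_id();
    const int prompt_count = 3;

    // phase 1: reserve an id for every subtask up front
    std::vector<int> subtask_ids(prompt_count);
    for (int i = 0; i < prompt_count; i++) {
        subtask_ids[i] = queue_tasks.get_new_id();
    }

    // phase 2: register the parent multitask before any subtask can run
    queue_tasks.add_multitask(multitask_id, subtask_ids);

    // phase 3: only now dispatch the subtasks (request_completion in the patch)
    for (int id : subtask_ids) {
        std::printf("dispatch subtask %d\n", id);
    }
    return 0;
}
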
@@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
return;
}
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@@ -2505,9 +2512,8 @@
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
{
@@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
break;
}
}
sink.done();

llama.queue_results.remove_waiting_task_id(task_id);
sink.done();
return true;
};

@@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
}
json data = oaicompat_completion_params_parse(json::parse(req.body));

const int task_id = llama.request_completion(data, false, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);

if (!json_value(data, "stream", false)) {
std::string completion_text;
@@ -2608,9 +2616,8 @@
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
@@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
return;
}
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, true, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@@ -2683,8 +2692,8 @@
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
return;
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while (true)
@@ -2700,6 +2709,7 @@
});
if (!sink.write(str.c_str(), str.size()))
{
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
if (result.stop)
@@ -2713,8 +2723,8 @@
}
}

llama.queue_results.remove_waiting_task_id(task_id);
sink.done();

return true;
};

@@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
image_data = "";
}

const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);

// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);

// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});
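
Across the handler changes above, the shape is always the same: the waiting task id is registered before request_completion() posts the task and removed once the handler is done with the result, including on the streaming providers' failure paths (note the remove_waiting_task_id() added before the early return false when sink.write() fails). Not part of this patch, but one common way to make a "deregister on every exit path" rule hard to break is a small RAII guard; waiting_set below is a stand-in for the real result queue:

#include <mutex>
#include <unordered_set>

struct waiting_set {
    std::mutex m;
    std::unordered_set<int> ids;
    void add_waiting_task_id(int id)    { std::lock_guard<std::mutex> l(m); ids.insert(id); }
    void remove_waiting_task_id(int id) { std::lock_guard<std::mutex> l(m); ids.erase(id); }
};

// Registers on construction, deregisters on destruction: the cleanup runs on
// early return, loop break-out, or exception, so no exit path can skip it.
struct waiting_guard {
    waiting_set & q;
    int id;
    waiting_guard(waiting_set & queue, int task_id) : q(queue), id(task_id) { q.add_waiting_task_id(id); }
    ~waiting_guard() { q.remove_waiting_task_id(id); }
};

int main() {
    waiting_set results;
    {
        waiting_guard guard(results, 42); // registered for the whole handler scope
        // ... post the task, recv() results, possibly return early ...
    }                                     // deregistered here, however the scope exits
    return 0;
}
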

8 changes: 5 additions & 3 deletions examples/server/utils.hpp
@@ -203,7 +203,9 @@ struct llama_server_queue {
// Add a new task to the end of the queue
int post(task_server task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
task.id = id++;
if (task.id == -1) {
task.id = id++;
}
queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
return task.id;
@@ -215,8 +217,8 @@
queue_tasks_deferred.push_back(std::move(task));
}

// Get the next task id
int get_next_id() {
// Get the next id for creating a new task
int get_new_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return id++;
}
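
The queue side of the fix is in the two hunks above: post() only assigns an id when the task arrives without one, so ids reserved up front via get_new_id() (renamed from get_next_id()) survive, while callers that never set an id keep the old auto-assignment. A simplified, single-threaded sketch of both paths; simple_queue is a stand-in for llama_server_queue with locking omitted, and the id = -1 default on the task struct is an assumption, since the struct definition is not part of this diff.

#include <cstdio>

struct task {
    int id = -1; // -1 means "not assigned yet", as the new post() check assumes
};

struct simple_queue {
    int id = 0;
    int get_new_id() { return id++; } // caller reserves an id before posting
    int post(task t) {
        if (t.id == -1) {
            t.id = id++;              // legacy path: the queue assigns the id
        }
        return t.id;
    }
};

int main() {
    simple_queue q;

    task pre;                     // new path: id reserved before the task is posted
    pre.id = q.get_new_id();
    std::printf("pre-assigned id: %d\n", q.post(pre));

    task legacy;                  // old path: id still assigned at post() time
    std::printf("queue-assigned id: %d\n", q.post(legacy));
    return 0;
}
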
