Use single thread per model instance (#2339)
lzhangzz authored Aug 20, 2024
1 parent 0e6e897 commit 0a1f65a
Showing 4 changed files with 100 additions and 118 deletions.
57 changes: 20 additions & 37 deletions lmdeploy/turbomind/turbomind.py
@@ -318,26 +318,17 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
self.nccl_params = tm_model.nccl_params

# create model instances
model_insts = [None] * self.gpu_count
with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
futures = []
for device_id in range(self.gpu_count):
futures.append(
executor.submit(self._create_model_instance, device_id,
model_insts))
for future in futures:
future.result()
self.model_inst = self._create_model_instance(0)

self.model_insts = model_insts
self.que = Queue()
self.executor: ThreadPoolExecutor = None
self.futures = [None] * self.gpu_count
self.future = None

def _create_model_instance(self, device_id, model_insts):
def _create_model_instance(self, device_id):
rank = self.node_id * self.gpu_count + device_id
model_inst = self.tm_model.model_comm.create_model_instance(
device_id, rank, self.cuda_stream_id, self.nccl_params)
model_insts[device_id] = model_inst
return model_inst

def _forward_callback(self, result, ctx):
self.que.put((False, result))
@@ -346,15 +337,12 @@ def _forward_thread(self, inputs):
instance_comm = self.tm_model.model_comm.create_instance_comm(
self.gpu_count)

def _func(device_id, enque_output):
output = self.model_insts[device_id].forward(inputs, instance_comm)
if enque_output:
self.que.put((True, output))
def _func():
output = self.model_inst.forward(inputs, instance_comm)
self.que.put((True, output))

self.executor = ThreadPoolExecutor(self.gpu_count)
for device_id in range(self.gpu_count):
f = self.executor.submit(_func, device_id, device_id == 0)
self.futures[device_id] = f
self.executor = ThreadPoolExecutor(1)
self.future = self.executor.submit(_func)

def _async_forward_callback(self, result, ctx, que: LifoQueue):
que.put((False, result))
@@ -363,15 +351,12 @@ def _async_forward_thread(self, inputs, que: LifoQueue):
instance_comm = self.tm_model.model_comm.create_instance_comm(
self.gpu_count)

def _func(device_id, enque_output):
output = self.model_insts[device_id].forward(inputs, instance_comm)
if enque_output:
que.put((True, output))
def _func():
output = self.model_inst.forward(inputs, instance_comm)
que.put((True, output))

self.executor = ThreadPoolExecutor(self.gpu_count)
for device_id in range(self.gpu_count):
f = self.executor.submit(_func, device_id, device_id == 0)
self.futures[device_id] = f
self.executor = ThreadPoolExecutor(1)
self.future = self.executor.submit(_func)

def _get_logprobs(self,
logprob_vals: torch.Tensor,
@@ -617,7 +602,7 @@ async def async_stream_infer(self,
_forward_thread = partial(self._async_forward_thread, que=que)
if stream_output and not stop:
logger.info(f'Register stream callback for {session_id}')
self.model_insts[0].register_callback(_forward_callback)
self.model_inst.register_callback(_forward_callback)

inputs, input_lengths = self.prepare_inputs(
session_id=session_id,
@@ -691,14 +676,13 @@ async def async_stream_infer(self,
yield outputs

if finish:
for f in self.futures:
f.result()
self.future.result()
self.executor.shutdown()
break

if stream_output and not stop:
logger.info(f'UN-register stream callback for {session_id}')
self.model_insts[0].unregister_callback()
self.model_inst.unregister_callback()

def stream_infer(self,
session_id,
@@ -730,7 +714,7 @@ def stream_infer(self,
"""
if stream_output and not stop:
logger.info(f'Register stream callback for {session_id}')
self.model_insts[0].register_callback(self._forward_callback)
self.model_inst.register_callback(self._forward_callback)

inputs, input_lengths = self.prepare_inputs(
session_id=session_id,
@@ -803,16 +787,15 @@ def stream_infer(self,
yield outputs

if finish:
for f in self.futures:
f.result()
self.future.result()
self.executor.shutdown()
while self.que.qsize() > 0:
self.que.get()
break

if stream_output and not stop:
logger.info(f'UN-register stream callback for {session_id}')
self.model_insts[0].unregister_callback()
self.model_inst.unregister_callback()

def decode(self,
input_ids,
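For reference, here is a minimal, self-contained sketch of the threading pattern the new `_forward_thread` adopts: a single worker thread drives one model instance and hands results back through a queue that the caller drains. `DummyModelInstance`, `forward_thread`, and the input/output shapes are hypothetical stand-ins for the TurboMind bindings, not part of this commit.

```python
from concurrent.futures import ThreadPoolExecutor
from queue import Queue


class DummyModelInstance:
    """Hypothetical stand-in for the C++ model-instance binding."""

    def forward(self, inputs):
        # Pretend to decode a few tokens and return the final tensors.
        return {'output_ids': [[1, 2, 3]], 'sequence_length': [3]}


def forward_thread(model_inst, inputs, que):
    """One worker thread per model instance, mirroring the new pattern."""

    def _func():
        output = model_inst.forward(inputs)
        que.put((True, output))  # (finished, output)

    executor = ThreadPoolExecutor(1)  # exactly one thread
    future = executor.submit(_func)
    return executor, future


que = Queue()
executor, future = forward_thread(DummyModelInstance(), {'input_ids': [[0]]}, que)

finished, output = que.get()  # caller blocks until the worker enqueues a result
if finished:
    future.result()  # surface any exception raised in the worker
    executor.shutdown()
print(output)
```

Compared with the previous fan-out of one thread per GPU, only a single future has to be awaited and shut down when a request finishes, which is what the updated `stream_infer`/`async_stream_infer` loops now do.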
78 changes: 48 additions & 30 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -102,9 +102,9 @@ void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r
if (r) {
int ec = 0;

const int input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
const int input_length = r->inputs.getVal<int>("input_lengths", 0);
const auto get_offset = [&](int token_count) {
return std::max(0, std::min(token_count, r->inputs[rank_].getVal<int>("step", token_count)));
return std::max(0, std::min(token_count, r->inputs.getVal<int>("step", token_count)));
};

if (occurrence[r->id] != 1) {
@@ -249,7 +249,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)

auto& seq = *state.sequences[idx];

if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
if (int step = r->inputs.getVal<int>("step", -1); step >= 0) {
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
@@ -261,8 +261,8 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
}
}

const int input_length = r->inputs[rank_].getVal<int>("input_lengths");
const int* input_ids = r->inputs[rank_].getPtr<int>("input_ids");
const int input_length = r->inputs.getVal<int>("input_lengths");
const int* input_ids = r->inputs.getPtr<int>("input_ids");

{
// `output_ids` contains all token ids of the sequences
@@ -285,16 +285,16 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
}

// copy input tokens to prompt for prefix matching
if (input_length && r->start_flag && !r->inputs[rank_].isExist("input_embedding_ranges")) {
if (input_length && r->start_flag && !r->inputs.isExist("input_embedding_ranges")) {
// TODO: truncate prompt to enable prefix caching for VLM
seq.prompt.resize(input_length);
std::copy_n(input_ids, input_length, seq.prompt.data());
}

// copy input embeddings
if (r->inputs[rank_].isExist("input_embedding_ranges")) {
const auto range_tensor = r->inputs[rank_].at("input_embedding_ranges");
const auto emb_tensor = r->inputs[rank_].at("input_embeddings");
if (r->inputs.isExist("input_embedding_ranges")) {
const auto range_tensor = r->inputs.at("input_embedding_ranges");
const auto emb_tensor = r->inputs.at("input_embeddings");
const int* ranges = range_tensor.getPtr<int>();

auto check_embeddings = [&](int& num_valid_embeddings) {
@@ -332,7 +332,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
range_tensor.toString().c_str());
}
else {
char* emb_tensor_ptr = emb_tensor.getPtr<char>();
const char* emb_tensor_ptr = emb_tensor.getPtr<char>();
for (size_t i = 0; i < num_valid_embeddings; i++) {
int begin = ranges[i * 2];
int end = ranges[i * 2 + 1];
@@ -344,7 +344,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
}
}

const int request_output_len = state.requests[idx]->inputs[rank_].getVal<int>("request_output_len");
const int request_output_len = state.requests[idx]->inputs.getVal<int>("request_output_len");
state.seq_len_limit[idx] = state.h_context_length[idx] + request_output_len;
// `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
@@ -386,7 +386,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)

if (r->start_flag) {
// prepare to initialize random state for new sequence
h_random_seed_[idx] = r->inputs[rank_].getVal<unsigned long long>("random_seed", 0);
h_random_seed_[idx] = r->inputs.getVal<unsigned long long>("random_seed", 0);
}
else {
// Recover device states if not a new sequence
@@ -1045,8 +1045,8 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
// find an exemplar that matches the param name
const Tensor* ptr{};
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(name)) {
ptr = &state_->requests[i]->inputs[rank_].at(name);
if (state_->requests[i]->inputs.isExist(name)) {
ptr = &state_->requests[i]->inputs.at(name);
break;
}
}
@@ -1061,8 +1061,8 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
int max_list_length = 0;
if (name == "bad_words_list" || name == "stop_words_list") {
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]->inputs[rank_].isExist(name)) {
Tensor& src = state_->requests[i]->inputs[rank_].at(name);
if (state_->requests[i]->inputs.isExist(name)) {
Tensor& src = state_->requests[i]->inputs.at(name);
FT_CHECK(src.shape.size() == 3 && src.shape[1] == 2 && src.shape[2] <= kMaxStopBadWordsLen);
max_list_length = std::max(max_list_length, (int)src.shape[2]);
}
@@ -1075,8 +1075,8 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
}
for (int i = 0; i < batch_size; ++i) {
FT_CHECK(state_->requests[i] != nullptr);
if (state_->requests[i]->inputs[rank_].isExist(name)) {
Tensor& src = state_->requests[i]->inputs[rank_].at(name);
if (state_->requests[i]->inputs.isExist(name)) {
Tensor& src = state_->requests[i]->inputs.at(name);
if (name == "bad_words_list" || name == "stop_words_list") {
int list_length = src.shape[2];
std::copy_n(src.getPtr<std::byte>(),
@@ -1127,7 +1127,7 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)

TensorMap outputs;
for (int i = 0; i < batch_size; i++) {
if (state_->requests[i]->inputs[rank_].isExist("logprobs")) {
if (state_->requests[i]->inputs.isExist("logprobs")) {
outputs.insert(
{"sampled_logprobs", {MEMORY_GPU, TYPE_FP32, {(size_t)batch_size, 1, kMaxLogProb}, sampled_logprobs_}});
outputs.insert(
@@ -1153,7 +1153,7 @@ void LlamaBatch<T>::OutputContextLogits(T* cont
bool is_return_logits = false;
for (int k = 0; k < indices.size(); ++k) {
auto& request = state_->requests[indices[k]];
auto logits = request->outputs[rank_].getPtr<float>("logits", nullptr);
auto logits = request->outputs.getPtr<float>("logits", nullptr);
if (logits && sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) {
logits = nullptr;
}
@@ -1185,6 +1185,11 @@ void LlamaBatch<T>::OutputContextLogits(T* cont

auto logits = context_logits_buf_;

// Only rank-0 writes to output
if (rank_ != 0) {
return;
}

for (int k = 0; k < indices.size(); ++k) {
if (output_logits[k]) {
auto src_ptr = logits;
@@ -1250,15 +1255,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
++state_->h_context_length[i];
}

{ // output logprobs, should be set before sequence_length
// ! Only rank-0 writes to output
if (rank_ == 0) {
// output logprobs, should be set before sequence_length
float* sampled_logprobs_ptr = h_sampled_logprobs_;
uint32_t* sampled_indexes_ptr = h_sampled_indexes_;
uint32_t* sampled_nums_ptr = h_sampled_nums_;
for (int i = 0; i < batch_size - g.partial; ++i) {
if (state_->requests[i] && state_->requests[i]->inputs[rank_].isExist("logprobs")) {
auto logprob_vals = state_->requests[i]->outputs[rank_].getPtr<float>("logprob_vals");
auto logprob_indexes = state_->requests[i]->outputs[rank_].getPtr<uint32_t>("logprob_indexes");
auto logprob_nums = state_->requests[i]->outputs[rank_].getPtr<uint32_t>("logprob_nums");
if (state_->requests[i] && state_->requests[i]->inputs.isExist("logprobs")) {
auto logprob_vals = state_->requests[i]->outputs.getPtr<float>("logprob_vals");
auto logprob_indexes = state_->requests[i]->outputs.getPtr<uint32_t>("logprob_indexes");
auto logprob_nums = state_->requests[i]->outputs.getPtr<uint32_t>("logprob_nums");

int offset = state_->h_context_length[i] - state_->h_prompt_length[i] - 1;
std::copy(sampled_logprobs_ptr,
@@ -1275,12 +1282,14 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
}
}

{ // set output tokens ids and sequence length
// ! Only rank-0 writes to output
if (rank_ == 0) {
// set output tokens ids and sequence length
int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size - g.partial; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
auto output_ids = state_->requests[i]->outputs[rank_].getPtr<int>("output_ids");
auto output_len = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
auto output_ids = state_->requests[i]->outputs.getPtr<int>("output_ids");
auto output_len = state_->requests[i]->outputs.getPtr<int>("sequence_length");
const int count = state_->h_context_length[i];
// TODO: sync history output tokens at when receiving the request and copy the last token here
std::copy(output_ptr, output_ptr + count, output_ids);
@@ -1322,7 +1331,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
signals.push_back([this, r = state_->requests[i]] {
if (rank_ == 0) {
try {
r->stream_cb(&r->outputs[rank_].get());
r->stream_cb(&r->outputs.get());
}
catch (const std::bad_function_call& e) {
TM_LOG_ERROR("Null stream callback for (%s)", std::to_string(r->id).c_str());
@@ -1379,7 +1388,16 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig

// Update token IDs
seq.tokens.resize(output_len);
const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
const auto output_ids_data = [&] {
if (force_stop) {
// `h_output_ids_` is UNDEFINED at `ProcessStopRequests`
return state_->requests[index]->outputs.at("output_ids").getPtr<int>();
}
else {
// `h_output_ids_` just updated by `Finish`, but `outputs` is NOT synced atm
return h_output_ids_ + index * (size_t)session_len_;
}
}();
std::copy_n(output_ids_data, output_len, seq.tokens.data());

// Save random state in host memory
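The LlamaBatch.cc half of the change removes the per-rank tensor maps (`r->inputs[rank_]` becomes `r->inputs`, and likewise for `outputs`) and, since those maps are now shared, guards every user-visible write with `rank_ == 0`. Below is a hedged Python sketch of that single-writer idea; the `request` dict and `finish` helper are illustrative only, not LMDeploy APIs.

```python
import threading

# Hypothetical request whose `outputs` map is shared by all ranks,
# replacing the old one-TensorMap-per-rank layout.
request = {'outputs': {}}


def finish(rank, tokens):
    """Every rank reaches this point, but only rank 0 writes user-visible output."""
    if rank != 0:
        return
    request['outputs']['output_ids'] = tokens
    request['outputs']['sequence_length'] = len(tokens)


threads = [threading.Thread(target=finish, args=(rank, [1, 2, 3])) for rank in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(request['outputs'])  # populated exactly once, by rank 0
```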