Merge branch 'smalton/INSTX-1180-missing-bases' into 'master'
INSTX-1180: Fix race condition

Closes INSTX-1180

See merge request machine-learning/dorado!472
malton-ont committed Jul 7, 2023
2 parents ddb6f71 + cbd67d5 commit 4d8ca17
Showing 2 changed files with 22 additions and 22 deletions.
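
Editor's note (not part of the commit): the code in this diff is a one-shot condition-variable handoff between a caller thread and a worker thread, and the commit hardens it in two ways. First, the task queues switch from raw pointers to stack-allocated tasks (NNTask *, ModBaseTask*) to std::shared_ptr, so the worker's reference can never outlive the storage it points to. Second, the worker now releases the task mutex before calling notify_one(), so the woken waiter does not immediately block on a mutex the worker still holds. A minimal sketch of the resulting pattern, with hypothetical names (Task, submit_and_wait, worker_complete are illustrative, not dorado APIs):

// Sketch of the post-fix handoff pattern; hypothetical names throughout.
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>

struct Task {
    std::mutex mut;
    std::condition_variable cv;
    bool done{false};
};

std::deque<std::shared_ptr<Task>> task_queue;  // shared_ptr, not raw Task*
std::mutex queue_mut;
std::condition_variable queue_cv;

void submit_and_wait() {
    auto task = std::make_shared<Task>();
    {
        std::lock_guard<std::mutex> lock(queue_mut);
        task_queue.push_front(task);  // the queue (and later the worker)
    }                                 // co-owns the task from here on
    queue_cv.notify_one();

    std::unique_lock lock(task->mut);
    task->cv.wait(lock, [&] { return task->done; });
    // Even if this thread returns immediately, the worker's shared_ptr
    // keeps the Task -- including its mutex -- alive until the worker
    // drops its copy.
}

void worker_complete(const std::shared_ptr<Task>& task) {
    std::unique_lock task_lock(task->mut);
    task->done = true;
    task_lock.unlock();     // release the mutex first...
    task->cv.notify_one();  // ...then notify, so the waiter wakes straight
                            // into an available mutex
}

The commit message does not spell out the exact interleaving behind the bug, but both changes are standard hardening for this pattern: with a stack-owned task and a raw pointer in the queue, any path on which the waiter returns before the worker is completely finished with that pointer is a use-after-free.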
24 changes: 12 additions & 12 deletions dorado/nn/CudaCRFModel.cpp
@@ -72,11 +72,12 @@ class CudaCaller {
     }
 
     struct NNTask {
-        NNTask(torch::Tensor input_, int num_chunks_) : input(input_), num_chunks(num_chunks_) {}
+        NNTask(torch::Tensor input_, torch::Tensor &output_, int num_chunks_)
+                : input(input_), out(output_), num_chunks(num_chunks_) {}
         torch::Tensor input;
+        torch::Tensor &out;
         std::mutex mut;
         std::condition_variable cv;
-        torch::Tensor out;
         bool done{false};
         int num_chunks;
     };
@@ -91,19 +92,18 @@ class CudaCaller {
         if (num_chunks == 0) {
             return std::vector<DecodedChunk>();
         }
-        NNTask task(input.to(m_options.device()), num_chunks);
+        auto task = std::make_shared<NNTask>(input.to(m_options.device()), output, num_chunks);
         {
             std::lock_guard<std::mutex> lock(m_input_lock);
-            m_input_queue.push_front(&task);
+            m_input_queue.push_front(task);
         }
         m_input_cv.notify_one();
 
-        std::unique_lock lock(task.mut);
-        while (!task.done) {
-            task.cv.wait(lock);
+        std::unique_lock lock(task->mut);
+        while (!task->done) {
+            task->cv.wait(lock);
         }
 
-        output.copy_(task.out);
         return m_decoder->cpu_part(output);
     }

@@ -131,7 +131,7 @@ class CudaCaller {
                 return;
             }
 
-            NNTask *task = m_input_queue.back();
+            auto task = m_input_queue.back();
             m_input_queue.pop_back();
             input_lock.unlock();

@@ -144,15 +144,15 @@ class CudaCaller {
             stats::Timer timer;
             auto scores = m_module->forward(task->input);
             const auto forward_ms = timer.GetElapsedMS();
-            task->out = m_decoder->gpu_part(scores, task->num_chunks, m_decoder_options);
+            task->out.copy_(m_decoder->gpu_part(scores, task->num_chunks, m_decoder_options));
             stream.synchronize();
             const auto forward_plus_decode_ms = timer.GetElapsedMS();
             ++m_num_batches_called;
             m_model_ms += forward_ms;
             m_decode_ms += forward_plus_decode_ms - forward_ms;
             task->done = true;
-            task->cv.notify_one();
             task_lock.unlock();
+            task->cv.notify_one();
         }
     }
     void terminate() { m_terminate.store(true); }
@@ -174,7 +174,7 @@ class CudaCaller {
     torch::nn::ModuleHolder<torch::nn::AnyModule> m_module{nullptr};
     size_t m_model_stride;
     std::atomic<bool> m_terminate{false};
-    std::deque<NNTask *> m_input_queue;
+    std::deque<std::shared_ptr<NNTask>> m_input_queue;
     std::mutex m_input_lock;
    std::condition_variable m_input_cv;
     std::unique_ptr<std::thread> m_cuda_thread;
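Editor's note (not part of the commit): alongside the ownership change, CudaCRFModel.cpp also turns NNTask::out from a tensor owned by the task into a reference to the caller's pre-allocated output buffer. The worker fills that buffer in place with copy_() before synchronizing the stream and setting done, which is why the caller-side output.copy_(task.out) disappears from the waiting code above. A minimal sketch of the borrowed-buffer idea, with hypothetical names and std::vector standing in for torch::Tensor:

#include <vector>

struct TaskLike {
    std::vector<float>& out;  // borrowed from the caller, filled in place
};

void worker_fill(TaskLike& task, const std::vector<float>& decoded) {
    // Analogous to task->out.copy_(m_decoder->gpu_part(...)): write the
    // result through the reference so the caller's buffer already holds
    // the data by the time `done` is observed.
    task.out.assign(decoded.begin(), decoded.end());
}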
20 changes: 10 additions & 10 deletions dorado/nn/ModBaseRunner.cpp
@@ -88,7 +88,7 @@ class ModBaseCaller {
         torch::nn::ModuleHolder<torch::nn::AnyModule> module_holder{nullptr};
         std::unique_ptr<RemoraScaler> scaler{nullptr};
         ModBaseParams params{};
-        std::deque<ModBaseTask*> input_queue;
+        std::deque<std::shared_ptr<ModBaseTask>> input_queue;
         std::mutex input_lock;
         std::condition_variable input_cv;
 #if DORADO_GPU_BUILD && !defined(__APPLE__)
@@ -191,20 +191,20 @@ class ModBaseCaller {
 #if DORADO_GPU_BUILD && !defined(__APPLE__)
         c10::cuda::OptionalCUDAStreamGuard stream_guard(caller_data->stream);
 #endif
-        ModBaseTask task(input_sigs.to(m_options.device()), input_seqs.to(m_options.device()),
-                         num_chunks);
+        auto task = std::make_shared<ModBaseTask>(input_sigs.to(m_options.device()),
+                                                  input_seqs.to(m_options.device()), num_chunks);
         {
             std::lock_guard<std::mutex> lock(caller_data->input_lock);
-            caller_data->input_queue.push_front(&task);
+            caller_data->input_queue.push_front(task);
         }
         caller_data->input_cv.notify_one();
 
-        std::unique_lock lock(task.mut);
-        while (!task.done) {
-            task.cv.wait(lock);
+        std::unique_lock lock(task->mut);
+        while (!task->done) {
+            task->cv.wait(lock);
         }
 
-        return task.out;
+        return task->out;
     }
 
     void modbase_task_thread_fn(size_t model_id) {
@@ -229,7 +229,7 @@ class ModBaseCaller {
                 return;
             }
 
-            ModBaseTask* task = caller_data->input_queue.back();
+            auto task = caller_data->input_queue.back();
             caller_data->input_queue.pop_back();
             input_lock.unlock();

@@ -246,8 +246,8 @@ class ModBaseCaller {
 #endif
             ++m_num_batches_called;
             task->done = true;
-            task->cv.notify_one();
             task_lock.unlock();
+            task->cv.notify_one();
         }
     }

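Editor's note (not part of the commit): ModBaseRunner.cpp receives the identical two-part treatment, a shared_ptr task queue plus unlock-before-notify, and its waiting code additionally returns task->out by value. With shared ownership that read is safe regardless of which thread finishes last, as this sketch (hypothetical names) illustrates:

#include <memory>

struct TaskLike { int out = 0; };  // stand-in for the real task

int caller_side(std::shared_ptr<TaskLike> task) {
    // The worker thread may still hold its own shared_ptr copy here;
    // either way, this thread's copy keeps the storage alive for the
    // read, and whichever owner releases last frees the task.
    return task->out;
}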
