Revert "Clear empty thread when graph destroy" #7860

Merged
Merged 1 commit on Mar 22, 2022
11 changes: 1 addition & 10 deletions .github/workflows/test.yml
@@ -529,7 +529,7 @@ jobs:
working-directory: ${{ env.ONEFLOW_SRC }}
run: |
docker run -d --rm --privileged --shm-size=8g \
--pids-limit 1000 \
--pids-limit -1 \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--runtime=nvidia \
-v /dataset:/dataset:ro -v /model_zoo:/model_zoo:ro \
@@ -691,14 +691,6 @@ jobs:
EXTRA_DOCKER_ARGS+=" --env ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1"
EXTRA_DOCKER_ARGS+=" --env ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1"
echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV
- name: Set Thread Limit (CPU)
if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cpu' }}
run: |
echo "THREAD_LIMIT=8000" >> $GITHUB_ENV
- name: Set Thread Limit (CUDA)
if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cuda' }}
run: |
echo "THREAD_LIMIT=3000" >> $GITHUB_ENV
- name: Enable ONEFLOW_TEST_VERBOSE
if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }}
run: |
@@ -718,7 +710,6 @@ jobs:
working-directory: ${{ env.ONEFLOW_SRC }}
run: |
docker run -d --rm --privileged --shm-size=8g \
--pids-limit ${{ env.THREAD_LIMIT }} \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--runtime=nvidia \
-v /dataset:/dataset:ro -v /model_zoo:/model_zoo:ro \
1 change: 0 additions & 1 deletion oneflow/core/framework/multi_client_session_context.cpp
@@ -141,7 +141,6 @@ Maybe<void> MultiClientSessionContext::TryInit(const std::string& config_proto_s
}

Maybe<void> MultiClientSessionContext::UpdateResource(const Resource& reso_proto) {
CHECK_OR_RETURN(is_inited_);
CHECK_NOTNULL_OR_RETURN((Global<ResourceDesc, ForSession>::Get()));
Global<ResourceDesc, ForSession>::Get()->Update(reso_proto);
return Maybe<void>::Ok();
39 changes: 19 additions & 20 deletions oneflow/core/framework/nn_graph.cpp
@@ -76,12 +76,9 @@ NNGraph::~NNGraph() {
Maybe<void> NNGraph::Close() {
if (!is_closed_) {
VLOG(1) << "Try to close c nn graph name " << name_ << "." << std::endl;
if (runtime_inited_) {
CloseRuntimeBuffers();
runtime_.reset();
}
CloseRuntimeBuffers();
runtime_.reset();
session_ctx_->RemoveGraphFreeEagerTensors(name_);
is_closed_ = true;
VLOG(1) << "Finish close c nn graph name " << name_ << "." << std::endl;

session_ctx_.reset();
@@ -431,23 +428,25 @@ void NNGraph::NewRuntimeBuffers() {
}

void NNGraph::CloseRuntimeBuffers() {
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
for (const std::string& output_op_name : outputs_op_names_) {
buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close();
if (runtime_inited_) {
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
for (const std::string& output_op_name : outputs_op_names_) {
buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close();
}
for (const std::string& input_op_name : inputs_op_names_) {
buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close();
}
buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close();
}
for (const std::string& input_op_name : inputs_op_names_) {
buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close();
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close();
buffer_mgr->Get(GetSourceTickBufferName(name_))->Close();
}
buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close();
}
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close();
buffer_mgr->Get(GetSourceTickBufferName(name_))->Close();
}
}

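Note on the nn_graph.cpp change: the runtime_inited_ check now lives inside CloseRuntimeBuffers() rather than at the call site in NNGraph::Close(), so Close() can call it unconditionally and the buffers are only touched when a runtime was actually started. A minimal sketch of that guarded-close pattern (hypothetical GraphLike type, not OneFlow's API):

#include <iostream>

// Sketch only: the callee carries the "was the runtime ever started?" guard,
// so the caller's close path stays unconditional.
class GraphLike {
 public:
  void InitRuntime() { runtime_inited_ = true; }
  void Close() {
    CloseRuntimeBuffers();  // safe to call whether or not a runtime exists
    is_closed_ = true;
  }

 private:
  void CloseRuntimeBuffers() {
    if (!runtime_inited_) { return; }  // no-op when the runtime never started
    std::cout << "closing runtime buffers" << std::endl;
  }
  bool runtime_inited_ = false;
  bool is_closed_ = false;
};

int main() {
  GraphLike never_ran;
  never_ran.Close();  // nothing to close

  GraphLike ran;
  ran.InitRuntime();
  ran.Close();  // prints "closing runtime buffers"
}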
15 changes: 1 addition & 14 deletions oneflow/core/job/runtime.cpp
@@ -58,11 +58,10 @@ bool HasNonCtrlConsumedRegstDescId(const TaskProto& task) {
} // namespace

Runtime::Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
DumpThreadIdsFromPlan(plan);
{
// NOTE(chengcheng): All runtime Global objects AddPlan
Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob);
Global<ThreadMgr>::Get()->AddThreads(thread_ids_);
Global<ThreadMgr>::Get()->AddPlan(plan);
Global<RuntimeJobDescs>::Get()->AddPlan(plan);
collective_boxing_scheduler_plan_token_ =
Global<boxing::collective::Scheduler>::Get()->AddPlan(plan);
@@ -107,19 +106,7 @@ Runtime::~Runtime() {
Global<RuntimeCtx>::Get()->WaitUntilCntEqualZero(GetRunningActorCountKeyByJobId(pair.first));
}
OF_SESSION_BARRIER();
Global<ThreadMgr>::Get()->TryDeleteThreads(thread_ids_);
Global<boxing::collective::Scheduler>::Get()->DeletePlan(collective_boxing_scheduler_plan_token_);
}

void Runtime::DumpThreadIdsFromPlan(const Plan& plan) {
const int64_t this_rank = GlobalProcessCtx::Rank();
for (const TaskProto& task : plan.task()) {
TaskId task_id = DecodeTaskIdFromInt64(task.task_id());
StreamId stream_id = task_id.stream_id();
if (stream_id.rank() != this_rank) { continue; }
int64_t thrd_id = EncodeStreamIdToInt64(stream_id);
thread_ids_.insert(thrd_id);
}
}

} // namespace oneflow
3 changes: 0 additions & 3 deletions oneflow/core/job/runtime.h
@@ -33,10 +33,7 @@ class Runtime final {
Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob);

private:
void DumpThreadIdsFromPlan(const Plan& plan);

HashMap<int64_t, int64_t> job_id2actor_size_;
HashSet<int64_t> thread_ids_;

boxing::collective::SchedulerPlanToken* collective_boxing_scheduler_plan_token_;
};
2 changes: 0 additions & 2 deletions oneflow/core/thread/thread.cpp
@@ -50,8 +50,6 @@ void Thread::AddTask(const TaskProto& task) {
CHECK(id2task_.emplace(task.task_id(), task).second);
}

bool Thread::Empty() const { return id2actor_ptr_.empty(); }

void Thread::PollMsgChannel() {
while (true) {
if (local_msg_queue_.empty()) {
2 changes: 0 additions & 2 deletions oneflow/core/thread/thread.h
@@ -34,8 +34,6 @@ class Thread {
virtual ~Thread();

void AddTask(const TaskProto&);
// NOTE(chengcheng): Indicates whether all actors in the thread have been destructed.
bool Empty() const;

Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }

47 changes: 13 additions & 34 deletions oneflow/core/thread/thread_manager.cpp
@@ -17,12 +17,17 @@ limitations under the License.
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {

ThreadMgr::~ThreadMgr() {
CHECK(threads_.empty()) << " Runtime Error! num = " << threads_.size()
<< " threads did not destroy with graph.";
for (auto& thread_pair : threads_) {
ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
thread_pair.second->GetMsgChannelPtr()->Send(msg);
thread_pair.second.reset();
VLOG(3) << "actor thread " << thread_pair.first << " finish";
}
}

Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
@@ -31,46 +36,20 @@ Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
return iter->second.get();
}

void ThreadMgr::AddThreads(const HashSet<int64_t>& thread_ids) {
void ThreadMgr::AddPlan(const Plan& plan) {
const int64_t this_rank = GlobalProcessCtx::Rank();
for (int64_t thrd_id : thread_ids) {
const auto& it = threads_.find(thrd_id);
if (it != threads_.end()) {
// NOTE(chengcheng): check thread is not null.
CHECK(it->second) << " Runtime Error! Thread: " << thrd_id << " in manager must be NOT null.";
continue;
}
StreamId stream_id = DecodeStreamIdFromInt64(thrd_id);
for (const TaskProto& task : plan.task()) {
TaskId task_id = DecodeTaskIdFromInt64(task.task_id());
StreamId stream_id = task_id.stream_id();
if (stream_id.rank() != this_rank) { continue; }
int64_t thrd_id = EncodeStreamIdToInt64(stream_id);
if (threads_.find(thrd_id) != threads_.end()) { continue; }
Thread* thread = new Thread(stream_id);
CHECK_NOTNULL(thread);
threads_[thrd_id].reset(thread);
}
}

void ThreadMgr::TryDeleteThreads(const HashSet<int64_t>& thread_ids) {
std::unique_lock<std::mutex> lock(mutex4del_threads_);
for (int64_t thrd_id : thread_ids) {
const auto& it = threads_.find(thrd_id);
if (it == threads_.end()) { continue; }
auto& thread = it->second;
CHECK(thread) << " actor thread " << thrd_id << " non-existent but want to delete";
if (thread->Empty()) {
// NOTE(chengcheng): Only delete thread when it is empty.
// We need send Stop msg to exit the main loop of the thread. Here we can safely call reset
// directly, because the thread destructor will specify actor_thread_.join() to blocking
// wait for the end of the thread real execution.
ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
thread->GetMsgChannelPtr()->Send(msg);
thread.reset();
VLOG(2) << " actor thread " << thrd_id << " finish.";
threads_.erase(it);
} else {
LOG(INFO) << " actor thread " << thrd_id << " not delete because it is not empty.";
}
}
}

void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
FOR_RANGE(size_t, i, 0, num) { Callback(i); }
}
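Note on the reverted lifecycle in thread_manager.cpp: ThreadMgr now creates one Thread per local stream found in the plan (AddPlan) and keeps it alive until the manager itself is destroyed, where each thread receives a kStopThread command and is joined; the per-graph TryDeleteThreads/Thread::Empty path is removed. A self-contained sketch of this create-once, stop-at-manager-teardown pattern (hypothetical Worker/WorkerMgr types and stop sentinel, not OneFlow's Thread/ActorMsg API):

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>

struct Msg {
  bool stop;    // plays the role of ActorCmd::kStopThread
  int payload;
};

class Worker {
 public:
  Worker() : poller_([this] { Poll(); }) {}
  ~Worker() { poller_.join(); }  // blocks until the stop message has been handled
  void Send(const Msg& msg) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push(msg);
    }
    cv_.notify_one();
  }

 private:
  void Poll() {
    while (true) {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return !queue_.empty(); });
      Msg msg = queue_.front();
      queue_.pop();
      lock.unlock();
      if (msg.stop) { return; }  // exit the polling loop on the stop sentinel
      // ... process msg.payload ...
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<Msg> queue_;
  std::thread poller_;  // declared last so the queue exists before polling starts
};

class WorkerMgr {
 public:
  ~WorkerMgr() {
    // Like the reverted ~ThreadMgr(): stop and join every worker at manager teardown,
    // rather than deleting empty workers eagerly when a graph is destroyed.
    for (auto& pair : workers_) {
      pair.second->Send(Msg{/*stop=*/true, /*payload=*/0});
      pair.second.reset();
    }
  }
  void Add(int64_t id) {
    if (workers_.find(id) != workers_.end()) { return; }  // create each id only once
    workers_[id] = std::make_unique<Worker>();
  }

 private:
  std::unordered_map<int64_t, std::unique_ptr<Worker>> workers_;
};

int main() {
  WorkerMgr mgr;
  mgr.Add(0);
  mgr.Add(0);  // deduplicated, like threads keyed by thrd_id in AddPlan
  mgr.Add(1);
}  // workers are stopped and joined when mgr goes out of scope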
5 changes: 1 addition & 4 deletions oneflow/core/thread/thread_manager.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_
#define ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_

#include <mutex>
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/auto_registration_factory.h"
@@ -37,15 +36,13 @@ class ThreadMgr final {
ThreadMgr() = default;
~ThreadMgr();

void AddThreads(const HashSet<int64_t>& thread_ids);
void TryDeleteThreads(const HashSet<int64_t>& thread_ids);
void AddPlan(const Plan& plan);
Thread* GetThrd(int64_t thrd_id);

private:
friend class Global<ThreadMgr>;

HashMap<int64_t, std::unique_ptr<Thread>> threads_;
std::mutex mutex4del_threads_;
};

void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback);