Skip to content

Commit

Permalink
Plan separation compile (#9920)
Browse files Browse the repository at this point in the history
## 背景
rank 数多时,master 编译 所有 rank 的 task node
- 顺序编译,速度慢;
- plan 大(可能超过 2G),总体要发送的数据规模可能达到上百G,传输太慢;

所以必须改成每个 rank 独立编译自己的执行计划。

## 测评数据
- 模拟 n
卡的数据并行:Oneflow-Inc/OneTeam#1679 (comment)
- 实测:Oneflow-Inc/OneTeam#1944

## 实现思路总结
Oneflow-Inc/OneTeam#1791

---------

Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: ZZK <359521840@qq.com>
Co-authored-by: guo-ran <360112263@qq.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: Ping Zhu <58718936+reygu@users.noreply.github.com>
Co-authored-by: Wang Yi <53533850+marigoold@users.noreply.github.com>
Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Co-authored-by: Yao Chi <later@usopp.net>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: Luyang <flowingsun007@163.com>
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: Shiyuan Shangguan <shiyuan@oneflow.org>
Co-authored-by: yuhao <72971170+howin98@users.noreply.github.com>
Co-authored-by: jackalcooper <jackalcooper@gmail.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Zhimin Yang <76760002+small1945@users.noreply.github.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Dongche Zhang <zhang2000dc@gmail.com>
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: Peihong Liu <mosout@qq.com>
Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
Co-authored-by: wyg1997 <wangyinggang@foxmail.com>
Co-authored-by: Liang Depeng <liangdepeng@gmail.com>
Co-authored-by: Yu OuYang <xuanjiuye@gmail.com>
Co-authored-by: WangYi <buaawangyi03@gmail.com>
Co-authored-by: rejoicesyc <47683675+rejoicesyc@users.noreply.github.com>
Co-authored-by: songyicheng <int.rejoice@gmail.com>
Co-authored-by: QI JUN <qijun1994@hotmail.com>
Co-authored-by: zhaoyongke <zhaoyongke@yeah.net>
Co-authored-by: JiaKui Hu <hjk1938927583@163.com>
Co-authored-by: cheng cheng <472491134@qq.com>
  • Loading branch information
1 parent f72ebf6 commit ae52678
Show file tree
Hide file tree
Showing 37 changed files with 849 additions and 243 deletions.
19 changes: 19 additions & 0 deletions oneflow/core/common/env_var/env_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ int64_t ThreadLocalEnvInteger();

DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_THRAED_LOCAL_CACHED_SIZE, 128 * 1024);

template<typename env_var>
const std::string& ThreadLocalEnvString();

#define DEFINE_THREAD_LOCAL_ENV_STRING(env_var, default_value) \
struct env_var {}; \
template<> \
inline const std::string& ThreadLocalEnvString<env_var>() { \
thread_local std::string value = GetStringFromEnv(OF_PP_STRINGIZE(env_var), default_value); \
return value; \
}

DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE, false);
// Default compilation mode during graph compilation. There 2 modes to choose:
// "naive", master rank compile the full plan.
// "rank_per_process", multi process(rank) run seperation compile.
DEFINE_THREAD_LOCAL_ENV_STRING(ONEFLOW_LAZY_COMPILE_MODE, "naive");
// Default number of threads during graph compilation.
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM, 16);

} // namespace oneflow

#endif // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_
347 changes: 272 additions & 75 deletions oneflow/core/framework/nn_graph.cpp

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion oneflow/core/framework/nn_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_
#define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_

#include <memory>
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/framework/op_expr.h"
Expand Down Expand Up @@ -107,6 +106,10 @@ class NNGraph final : public NNGraphIf {
std::vector<std::shared_ptr<one::UserOpExpr>> cached_op_exprs;

private:
// Compile the full task graph for all ranks and then broadcast to all ranks.
Maybe<void> NaiveCompile();
// Each rank compile it's task graph.
Maybe<void> MasterAndWorkerRanksCompile();
Maybe<void> RegisterFreeEagerTensorsToVariableOpNames();
Maybe<void> RegisterNewVariableOpInJobPass();
Maybe<void> DeleteOutdatedVariableInVariableTensorMgr();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc&
}

inline std::string NewUniqueIdGbc() {
// The boxing task graph is built on rank 0 and broadcasted to all the ranks,
// so the ids here are unique among all the ranks.
static std::atomic<int64_t> counter(0);
static std::atomic<int64_t> curr_job_id(0);
if (curr_job_id != GlobalJobDesc().job_id()) {
Expand Down
13 changes: 12 additions & 1 deletion oneflow/core/graph/compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/fake_consumed_regst_provider.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/job/compile_mode.h"

namespace oneflow {

Expand Down Expand Up @@ -52,7 +53,9 @@ class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {
std::shared_ptr<const Operator> op() const { return op_node_->shared_op(); }

ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override {
return &ExecNode::InferBlobDescsByInputs;
// For default compilation mode, compute task node use input blob desc to infer output blob
// desc; For separate compilation mode, compute task node use NdSBP to infer output blob desc.
return InferBlobDescsMethodGetter::Visit(CHECK_JUST(CurrentCompileMode()));
}

protected:
Expand All @@ -64,6 +67,14 @@ class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {
void InferProducedDataRegstTimeShape() override;

private:
struct InferBlobDescsMethodGetter final : public CompileModeVisitor<InferBlobDescsMethodGetter> {
static ExecNode::InferBlobDescsMethod VisitNaive() { return &ExecNode::InferBlobDescsByInputs; }
static ExecNode::InferBlobDescsMethod VisitRankPerProcess() {
return &ExecNode::InferBlobDescsByNdSbp;
}
static ExecNode::InferBlobDescsMethod VisitInValid() { return nullptr; }
};

ParallelContext parallel_ctx_;
const OpNode* op_node_;
HashSet<std::string> fake_consumed_regst_names_;
Expand Down
4 changes: 3 additions & 1 deletion oneflow/core/graph/copy_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/graph/task_stream_id.h"
#include "oneflow/core/graph/boxing_task_graph.pb.h"
#include "oneflow/core/framework/user_op_registry_manager.h"

namespace oneflow {
Expand Down Expand Up @@ -95,7 +96,8 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
} else {
LOG(FATAL) << "unknow copy type: " << copy_type_;
}
conf.set_name(std::string(copy_type_name) + "_" + NewUniqueId());
conf.set_name(std::string(copy_type_name) + "_" + lbi().op_name() + "-" + lbi().blob_name() + "_"
+ std::to_string(task_id()));
*conf.mutable_user_conf()->mutable_op_type_name() = copy_type_name;
auto in_regst = GetSoleConsumedRegst("copy_in");
CHECK_EQ(in_regst->NumOfLbi(), 1);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/exec_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) {
nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp));
}
CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()),
std::stringstream() << " infer blob descs if failed, op name " << op_->op_loc());
std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc());
if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) {
CHECK_JUST(CheckPhysicalBlobDesc(
*op(), op()->output_bns(),
Expand Down
12 changes: 8 additions & 4 deletions oneflow/core/graph/plan_task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/graph/plan_task_graph.h"

namespace oneflow {
Expand All @@ -35,15 +36,18 @@ void PlanTaskGraph::InitEdges() {
PlanTaskNode* producer_node = task_id_and_plan_task_node.second;
for (const auto& pair : producer_node->task_proto()->produced_regst_desc()) {
for (int64_t consumer_task_id : pair.second.consumer_task_id()) {
PlanTaskNode* consumer_node = task_id2plan_task_node_.at(consumer_task_id);
Connect(producer_node, NewEdge(), consumer_node);
PlanTaskNode* consumer_node = CHECK_JUST(MapAt(task_id2plan_task_node_, consumer_task_id));
TryConnect(producer_node, consumer_node);
}
}
}
}

const TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const {
return task_id2plan_task_node_.at(task_id)->task_proto();
void PlanTaskGraph::TryConnect(PlanTaskNode* src, PlanTaskNode* dst) {
if (edges_.insert({src, dst}).second) { Connect(src, NewEdge(), dst); }
}

const TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const {
return CHECK_JUST(MapAt(task_id2plan_task_node_, task_id))->task_proto();
}
} // namespace oneflow
9 changes: 5 additions & 4 deletions oneflow/core/graph/plan_task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,24 @@ class PlanTaskNode final : public Node<PlanTaskNode, PlanTaskEdge> {
const TaskProto* task_proto_;
};

class PlanTaskGraph final : public Graph<const PlanTaskNode, PlanTaskEdge> {
class PlanTaskGraph : public Graph<const PlanTaskNode, PlanTaskEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(PlanTaskGraph);
explicit PlanTaskGraph(const Plan& plan);
~PlanTaskGraph() = default;
virtual ~PlanTaskGraph() = default;

const TaskProto* TaskProto4TaskId(int64_t task_id) const;
const Plan& plan() const { return *plan_; }

private:
protected:
void InitNodes();
void InitEdges();
void TryConnect(PlanTaskNode* src, PlanTaskNode* dst);

const Plan* plan_;
HashMap<int64_t, PlanTaskNode*> task_id2plan_task_node_;
HashSet<std::pair<PlanTaskNode*, PlanTaskNode*>> edges_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_
3 changes: 2 additions & 1 deletion oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"

#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/task_graph_rebuild_ctx.h"

namespace oneflow {
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/graph/task_id_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void TaskIdGenerator::TryUpdateTaskIndex(const HashMap<int64_t, uint32_t>& task_
}

TaskId TaskIdGenerator::Generate(const StreamId& stream_id) {
std::unique_lock<std::mutex> lock(mutex_);
if (stream_id2task_index_counter_.count(stream_id) == 0) {
uint32_t init_task_index = 0;
const int64_t i64_stream_id = EncodeStreamIdToInt64(stream_id);
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/graph/task_id_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TaskIdGenerator final {
void TryUpdateTaskIndex(const HashMap<int64_t, uint32_t>& task_index_state);

private:
std::mutex mutex_;
HashMap<StreamId, task_index_t> stream_id2task_index_counter_;
// The task_index_init_state is used to initialize the `stream_id2task_index_counter_` hashmap.
HashMap<int64_t, uint32_t> task_index_init_state_{};
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/graph/task_stream_index_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ limitations under the License.

namespace oneflow {

class TaskStreamIndexManager final {
class TaskStreamIndexManager {
public:
using stream_index_t = StreamId::stream_index_t;

OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexManager);
TaskStreamIndexManager() = default;
~TaskStreamIndexManager() = default;
virtual ~TaskStreamIndexManager() = default;

StreamIndexGenerator* GetGenerator(const DeviceId& device_id);
stream_index_t GetTaskStreamIndex(TaskType task_type, const DeviceId& device_id);
Expand Down
64 changes: 64 additions & 0 deletions oneflow/core/job/compile_mode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job/compile_mode.h"
#include "oneflow/core/common/env_var/env_var.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {

namespace {

struct CompileModeName final : public CompileModeVisitor<CompileModeName> {
static std::string VisitNaive() { return "naive"; }
static std::string VisitRankPerProcess() { return "rank_per_process"; }
static std::string VisitInValid() { return "invalid"; }
};

std::unordered_map<std::string, CompileMode> Name2CompileMode() {
std::unordered_map<std::string, CompileMode> name2compile_mode;
for (int i = static_cast<int>(CompileMode::kInvalid) + 1;
i != static_cast<int>(CompileMode::kEnd); ++i) {
CompileMode compile_mode = static_cast<CompileMode>(i);
CHECK(name2compile_mode.emplace(CompileModeName::Visit(compile_mode), compile_mode).second);
}
return name2compile_mode;
}

std::string GetValidCompileModeNames() {
std::stringstream ss;
for (int i = static_cast<int>(CompileMode::kInvalid) + 1;
i != static_cast<int>(CompileMode::kEnd); ++i) {
if (i > static_cast<int>(CompileMode::kInvalid) + 1) { ss << ", "; }
CompileMode compile_mode = static_cast<CompileMode>(i);
ss << CompileModeName::Visit(compile_mode);
}
return ss.str();
}

} // namespace

Maybe<CompileMode> CurrentCompileMode() {
static thread_local CompileMode mode =
JUST_MSG(MapAt(Name2CompileMode(), ThreadLocalEnvString<ONEFLOW_LAZY_COMPILE_MODE>()),
std::stringstream()
<< "ONEFLOW_LAZY_COMPILER(value: "
<< ThreadLocalEnvString<ONEFLOW_LAZY_COMPILE_MODE>()
<< ") is invalid. valid options: \"" << GetValidCompileModeNames() << "\"");
return mode;
}

} // namespace oneflow
50 changes: 50 additions & 0 deletions oneflow/core/job/compile_mode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_COMPILE_MODE_H_
#define ONEFLOW_CORE_JOB_COMPILE_MODE_H_

#include "oneflow/core/common/maybe.h"

namespace oneflow {

enum class CompileMode {
kInvalid = 0, // make sure kInvalid is the first CompileMode
kNaive,
kRankPerProcess,
kEnd, // make sure kEnd is the last CompileMode
};

template<typename DerivedT>
struct CompileModeVisitor {
template<typename... Args>
static auto Visit(CompileMode compile_mode, Args&&... args) {
switch (compile_mode) {
case CompileMode::kNaive: return DerivedT::VisitNaive(std::forward<Args>(args)...);
case CompileMode::kRankPerProcess:
return DerivedT::VisitRankPerProcess(std::forward<Args>(args)...);
default: {
LOG(FATAL) << "invalid compile mode";
return DerivedT::VisitInValid(std::forward<Args>(args)...);
}
}
}
};

Maybe<CompileMode> CurrentCompileMode();

} // namespace oneflow

#endif // ONEFLOW_CORE_JOB_COMPILE_MODE_H_
30 changes: 27 additions & 3 deletions oneflow/core/job/id_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,38 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/job/id_state.h"

namespace oneflow {

namespace {

constexpr static int64_t kRankLimitShift = 16;
constexpr static int64_t kIdLimitShift = (sizeof(int64_t) * 8 - kRankLimitShift);
static_assert(kIdLimitShift > 0, "");

int64_t AddCurrentRankOffset(int64_t x) {
CHECK_GE(x, 0);
CHECK_LT(x, (static_cast<int64_t>(1) << kIdLimitShift));
return (static_cast<int64_t>(GlobalProcessCtx::Rank()) << kIdLimitShift) + x;
}

} // namespace

IDMgr::IDMgr() {
regst_desc_id_count_ = 0;
mem_block_id_count_ = 0;
chunk_id_count_ = 0;
CHECK_LE(GlobalProcessCtx::WorldSize(), (static_cast<int64_t>(1) << kRankLimitShift));
}

int64_t IDMgr::NewRegstDescId() { return AddCurrentRankOffset(regst_desc_id_count_++); }

int64_t IDMgr::NewMemBlockId() { return AddCurrentRankOffset(mem_block_id_count_++); }

int64_t IDMgr::NewChunkId() { return AddCurrentRankOffset(chunk_id_count_++); }
void IDMgr::SaveIdAndTaskIndex(IdState* id_state) {
id_state->regst_desc_id_state_ = regst_desc_id_count_;
id_state->mem_block_id_state_ = mem_block_id_count_;
Expand All @@ -33,9 +54,12 @@ void IDMgr::SaveIdAndTaskIndex(IdState* id_state) {
}

void IDMgr::TryUpdateIdAndTaskIndex(const IdState* id_state) {
regst_desc_id_count_ = std::max(regst_desc_id_count_, id_state->regst_desc_id_state_);
mem_block_id_count_ = std::max(mem_block_id_count_, id_state->mem_block_id_state_);
chunk_id_count_ = std::max(chunk_id_count_, id_state->chunk_id_state_);
regst_desc_id_count_ = std::max(regst_desc_id_count_.load(std::memory_order_relaxed),
id_state->regst_desc_id_state_);
mem_block_id_count_ =
std::max(mem_block_id_count_.load(std::memory_order_relaxed), id_state->mem_block_id_state_);
chunk_id_count_ =
std::max(chunk_id_count_.load(std::memory_order_relaxed), id_state->chunk_id_state_);
task_id_gen_.TryUpdateTaskIndex(id_state->task_index_state_);
}

Expand Down
Loading

0 comments on commit ae52678

Please sign in to comment.