diff --git a/oneflow/core/common/env_var/env_var.h b/oneflow/core/common/env_var/env_var.h index e40dddb91f9..8c76ab30364 100644 --- a/oneflow/core/common/env_var/env_var.h +++ b/oneflow/core/common/env_var/env_var.h @@ -70,6 +70,25 @@ int64_t ThreadLocalEnvInteger(); DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_THRAED_LOCAL_CACHED_SIZE, 128 * 1024); +template +const std::string& ThreadLocalEnvString(); + +#define DEFINE_THREAD_LOCAL_ENV_STRING(env_var, default_value) \ + struct env_var {}; \ + template<> \ + inline const std::string& ThreadLocalEnvString() { \ + thread_local std::string value = GetStringFromEnv(OF_PP_STRINGIZE(env_var), default_value); \ + return value; \ + } + +DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE, false); +// Default compilation mode during graph compilation. There 2 modes to choose: +// "naive", master rank compile the full plan. +// "rank_per_process", multi process(rank) run seperation compile. +DEFINE_THREAD_LOCAL_ENV_STRING(ONEFLOW_LAZY_COMPILE_MODE, "naive"); +// Default number of threads during graph compilation. +DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_LAZY_COMPILE_RPC_THREAD_NUM, 16); + } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_ENV_VAR_H_ diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index bbf7d8330d6..e69de8ac522 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -15,7 +15,9 @@ limitations under the License. */ #include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/cost_util.h" @@ -32,6 +34,8 @@ limitations under the License. #include "oneflow/core/functional/functional.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/compiler.h" +#include "oneflow/core/job/rank_compiler.h" +#include "oneflow/core/graph/task_graph.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_instance.h" @@ -41,10 +45,15 @@ limitations under the License. #include "oneflow/core/job/utils/progress_bar.h" #include "oneflow/core/job_rewriter/job_completer.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/virtual_machine.h" +#include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/framework/variable_tensor_mgr.h" +#include "oneflow/core/common/env_var/env_var.h" +#include "oneflow/core/job/compile_mode.h" +#include "oneflow/core/thread/thread_manager.h" namespace oneflow { @@ -326,6 +335,267 @@ Maybe NNGraph::DeleteOutdatedVariableInVariableTensorMgr() { return Maybe::Ok(); } +namespace { + +// A templated function that broadcasts data from the master process to worker processes in a +// multi-threaded manner. Return push/pull keys only in master process. +template +std::set MultiThreadBroadcastFromMasterToWorkers(size_t world_size, + const std::string& prefix, + const X& master_data, + Y* worker_data) { + const size_t thread_num = ThreadLocalEnvInteger(); + const size_t split_num = std::sqrt(world_size); + BalancedSplitter bs(world_size, split_num); + std::set keys; + if (GlobalProcessCtx::IsThisProcessMaster()) { + std::mutex mtx4keys; + std::string data; + master_data.SerializeToString(&data); + MultiThreadLoop( + split_num, + [&](int i) { + std::string key = prefix + std::to_string(i); + Singleton::Get()->PushKV(key, data); + std::lock_guard lock(mtx4keys); + CHECK(keys.insert(key).second); + }, + thread_num); + } else { + const int64_t bs_index = bs.GetRangeIndexForVal(GlobalProcessCtx::Rank()); + std::string key = prefix + std::to_string(bs_index); + Singleton::Get()->PullKV(key, worker_data); + } + return keys; +} + +// A templated function that pushes data from the master process to each worker process using the +// control client. The function takes as input a prefix for the key used to store the data in the +// control client, a pointer to the data to be pushed, and a callable object PrepareEach that +// preprocesses the worker's data. Return push/pull keys only in master process. +template +std::set MultiThreadPushFromMasterToWorkers(const std::string& prefix, T* data, + const PrepareEachT& PrepareEach) { + const size_t thread_num = ThreadLocalEnvInteger(); + constexpr int kWorkerStartRank = 1; + std::set keys{}; + if (GlobalProcessCtx::IsThisProcessMaster()) { + std::mutex mtx4keys; + MultiThreadLoop( + GlobalProcessCtx::WorldSize(), + [&](int i) { + if (i < kWorkerStartRank) { return; } + T worker_data; + std::string key = prefix + std::to_string(i); + PrepareEach(&worker_data, i); + Singleton::Get()->PushKV(key, worker_data); + std::lock_guard lock(mtx4keys); + CHECK(keys.emplace(key).second) << "redundant pull key: " << key; + }, + thread_num); + } else { + Singleton::Get()->PullKV(prefix + std::to_string(GlobalProcessCtx::Rank()), data); + } + return keys; +} + +void DumpCalculationPassName(Job* job) { + for (int i = 0; i < job->net().op_size(); ++i) { + auto* op_conf = job->mutable_net()->mutable_op(i); + if (op_conf->has_scope_symbol_id()) { + const auto& scope = Singleton>::Get()->Get(op_conf->scope_symbol_id()); + op_conf->set_calculation_pass_name(scope.scope_proto().calculation_pass_name()); + } + } +} + +} // namespace + +// The main logic of separation plan compilation. Each rank (process) compile it's related task +// nodes. This can reduce plan compile time and avoid transport large plan protobuf. +// When master compile the full plan, some plan protos are much larger than 1GB, but protobuf has +// 2GB limitation and larg files are slow to transport. So we mush do separatioin plan compile when +// total rank num is large. +// Separation plan compilation is done by: +// a. Master broadcast job(or logical graph) to all workers, make all rank use the same job. +// b. Mater compile BoxingTaskGraph and broadcast it to all workers. BoxingTaskGraph needs to be +// done on master rank. +// c. Each rank compile it's related task node with RankCompiler. RankCompiler compile with the +// BoxingTaskGraph and the job. +Maybe NNGraph::MasterAndWorkerRanksCompile() { + // Seperation compile mode only works with nccl use compute stream and logical chain. + CHECK_OR_RETURN(EnableLogicalChain()) + << Error::RuntimeError() + << "nn.Graph separete compilation needs to work with logical chain enabled."; + // Note that nccl use compute stream mode has not need to generate CollectiveBoxingPlan. + CHECK_OR_RETURN((Singleton::Get()->nccl_use_compute_stream())) + << Error::RuntimeError() + << "nn.Graph separete compilation needs to work with nccl using compute stream enabled."; + + std::set push_pull_keys{}; + const auto& MergeCommKeys = [&](std::set&& keys) { + push_pull_keys.insert(keys.begin(), keys.end()); + }; + if (GlobalProcessCtx::IsThisProcessMaster()) { DumpCalculationPassName(&job_); } + + // a. Master broadcast job(or logical graph) to all workers, make all rank use the same job. + const size_t world_size = GlobalProcessCtx::WorldSize(); + MergeCommKeys(MultiThreadBroadcastFromMasterToWorkers( + world_size, name_ + std::string(__FUNCTION__) + "_job", job_, &job_)); + OpGraphSingletonGuard op_graph_guard(job_); + size_t rank = GlobalProcessCtx::Rank(); + + // b. Mater compile BoxingTaskGraph and broadcast it to all workers. BoxingTaskGraph needs to be + // done on master rank. + auto boxing_task_graph_proto = std::make_shared(); + std::shared_ptr boxing_task_graph; + if (GlobalProcessCtx::IsThisProcessMaster()) { + const auto& ParallelLoop = [](size_t work_num, const std::function& Work) { + MultiThreadLoop(work_num, Work, -1); + }; + boxing_task_graph = JUST(BoxingTaskGraph::New(ParallelLoop)); + boxing_task_graph->ToProto([](TaskNode*) { return true; }, boxing_task_graph_proto.get()); + if (Singleton::Get()->enable_debug_mode()) { + TeePersistentLogStream::Create("boxing_task_" + name_ + "_plan" + std::to_string(0)) + ->Write(*boxing_task_graph_proto); + } + } + const auto& PrepareWorkerBoxingTaskGraphProto = [&](BoxingTaskGraphProto* proto, int64_t i) { + boxing_task_graph->ToProto( + [i](TaskNode* task_node) { return BoxingTaskGraph::SelectTaskNodeByRank(task_node, i); }, + proto); + if (Singleton::Get()->enable_debug_mode()) { + TeePersistentLogStream::Create("boxing_task_" + name_ + "_plan" + std::to_string(i)) + ->Write(*proto); + } + }; + MergeCommKeys(MultiThreadPushFromMasterToWorkers( + name_ + std::string(__FUNCTION__) + "_boxing_task_graph", boxing_task_graph_proto.get(), + PrepareWorkerBoxingTaskGraphProto)); + + // c. Each rank compile it's related task node with RankCompiler. RankCompiler compile with the + // BoxingTaskGraph and the job. + auto* plan = &plan_; + CHECK_JUST(RankCompiler(boxing_task_graph_proto, rank).Compile(variable_op_names_, &job_, plan)); + PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names_); + + if (Singleton::Get()->enable_debug_mode()) { + TeePersistentLogStream::Create("job_" + name_ + "_plan" + std::to_string(rank))->Write(*plan); + PlanUtil::ToDotFile(*plan, "job_" + name_ + "_plan_" + std::to_string(rank) + ".dot"); + } + PlanUtil::GenRegisterHint(plan); + PlanUtil::DumpCtrlRegstInfoToPlan(plan); + PlanUtil::PlanMemoryLog(&plan_, name_); + if (Singleton::Get()->enable_debug_mode()) { + PlanUtil::GenLightPlan(&plan_, name_, rank); + } + OF_SESSION_BARRIER(); + for (const auto& k : push_pull_keys) { Singleton::Get()->ClearKV(k); } + OF_SESSION_BARRIER(); + return Maybe::Ok(); +} + +// Master compile the full plan. +Maybe NNGraph::NaiveCompile() { + auto compile_tc = std::make_unique>(true, true); + if (GlobalProcessCtx::IsThisProcessMaster()) { + auto sub_compile_tc = std::make_unique>(true, true); + // TODO(chengcheng): new memory reused by chunk + Compiler().Compile(&job_, &plan_); + sub_compile_tc->Count("[PlanCompile]" + name_ + " GenerateBasePlan", 1); + PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(&plan_, variable_op_names_); + sub_compile_tc->Count("[PlanCompile]" + name_ + " GenMemBlockAndChunk", 1); + PlanUtil::GenRegisterHint(&plan_); + sub_compile_tc->Count("[PlanCompile]" + name_ + " GenRegisterHint", 1); + // TODO(chengcheng): test collective boxing for multi-job. + PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_); + // PlanUtil::SetForceInplaceMemBlock(&plan_); NOTE(chengcheng): only for ssp. + sub_compile_tc->Count("[PlanCompile]" + name_ + " GenCollectiveBoxingPlan", 1); + PlanUtil::DumpCtrlRegstInfoToPlan(&plan_); + sub_compile_tc->Count("[PlanCompile]" + name_ + " DumpCtrlRegstInfoToPlan", 1); + PlanUtil::PlanMemoryLog(&plan_, name_); + if (Singleton::Get()->enable_debug_mode()) { + PlanUtil::GenLightPlan(&plan_, name_); + } + sub_compile_tc->Count("[GraphCompile]" + name_ + " GenMemAndLightPlanLog", 1, true); + } + compile_tc->Count("[GraphCompile]" + name_ + " CompilePlan", 0); + if (GlobalProcessCtx::WorldSize() > 1) { + std::string plan_name = "plan:" + job_name(); + if (GlobalProcessCtx::IsThisProcessMaster()) { + // TODO(chengcheng): split plan for each rank. + Singleton::Get()->PushKV(plan_name, plan_); + } else { + Singleton::Get()->PullKV(plan_name, &plan_); + } + OF_SESSION_BARRIER(); + if (GlobalProcessCtx::IsThisProcessMaster()) { + Singleton::Get()->ClearKV(plan_name); + } + } + compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true); + return Maybe::Ok(); +} + +// There are four plan compilation modes, with the first mode "master compilation" (default) and the +// fourth mode "rank separation compilation" being the ones actually used. +Maybe NNGraph::CompilePlanForRuntime() { + // A global variable to get graph configurations. + auto current_graph_config = std::make_unique(job_.job_conf(), job_id()); + auto compile_tc = std::make_unique>(true, true); + typedef Maybe (NNGraph::*CompileMethodT)(); + struct GetCompileMethod final : public CompileModeVisitor { + static CompileMethodT VisitNaive() { + // Master rank compile the full plan. + return &NNGraph::NaiveCompile; + } + static CompileMethodT VisitRankPerProcess() { + // Multi process(rank) run seperation compile. + return &NNGraph::MasterAndWorkerRanksCompile; + } + static CompileMethodT VisitInValid() { return nullptr; } + }; + JUST((this->*GetCompileMethod::Visit(JUST(CurrentCompileMode())))()); + compile_tc->Count("[GraphCompile]" + name_ + " CompileAndSyncPlan", 0); + PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); + compile_tc->Count("[GraphCompile]" + name_ + " PopulateOpAttribute", 0); + return Maybe::Ok(); +} + +Maybe NNGraph::InitRuntime() { + CHECK_OR_RETURN(!runtime_inited_) + << Error::RuntimeError() << "nn.Graph runtime is already initialized"; + + auto compile_tc = std::make_unique>(true, true); + NewRuntimeBuffers(); + + JUST(GetVariableRealBlobAfterSyncPlan()); + + // NOTE(strint): Do memory shrink to free cached memory in eager VM before graph runtime init. + JUST(vm::CurrentRankSync()); + auto* vm = JUST(SingletonMaybe()); + JUST(vm->ShrinkAllMem()); + + if (Singleton::Get()->enable_debug_mode()) { + auto cur_rank = GlobalProcessCtx::Rank(); + auto plan_name = "job_" + name_ + "_plan"; + if (JUST(CurrentCompileMode()) != CompileMode::kNaive) { + plan_name += std::to_string(cur_rank); + } + if (cur_rank == 0 || JUST(CurrentCompileMode()) != CompileMode::kNaive) { + TeePersistentLogStream::Create(plan_name)->Write(plan_); + PlanUtil::ToDotFile(plan_, plan_name + ".dot"); + } + } + + runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_)); + compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true); + JUST(LogProgress("[GraphCompile]" + name_ + " Done", true)); + + runtime_inited_ = true; + return Maybe::Ok(); +} + Maybe NNGraph::AlignStatesAfterLogicalGraphCompile() { auto compile_tc = std::make_unique>(true, true); JUST(RegisterFreeEagerTensorsToVariableOpNames()); @@ -379,9 +649,8 @@ Maybe NNGraph::BuildWithNewInputFromSharedGraph( CHECK_EQ_OR_RETURN(new_build_original_job.net().op_size(), shared_op_names_from_ordered_original_graph.size()) << "nn.Graph " << name_ - << " new_build_original_job op size and shared_op_names_from_ordered_original_graph size are " - "not " - "equal."; + << " new_build_original_job op size and shared_op_names_from_ordered_original_graph " + << "size are not equal."; HashMap shared_op_name2_new_op; for (int64_t op_idx = 0; op_idx < shared_op_names_from_ordered_original_graph.size(); ++op_idx) { // Assume that the new graph and the shared graph from nn.Graph.build have the same op order. @@ -406,78 +675,6 @@ Maybe NNGraph::BuildWithNewInputFromSharedGraph( return Maybe::Ok(); } -Maybe NNGraph::CompilePlanForRuntime() { - auto compile_tc = std::make_unique>(true, true); - // A global variable to get graph configurations. - auto current_graph_config = std::make_unique(job_.job_conf(), job_id()); - if (GlobalProcessCtx::IsThisProcessMaster()) { - // TODO(chengcheng): new memory reused by chunk - Compiler().Compile(&job_, &plan_); - auto sub_compile_tc = std::make_unique>(true, true); - PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(&plan_, variable_op_names_); - sub_compile_tc->Count("[GraphCompile]" + name_ + " GenMemBlockAndChunk", 1, true); - if (Singleton::Get()->enable_debug_mode()) { - TeePersistentLogStream::Create("job_" + name_ + "_plan")->Write(plan_); - PlanUtil::ToDotFile(plan_, "job_" + name_ + "_plan.dot"); - } - sub_compile_tc->Count("[GraphCompile]" + name_ + " LogPlan", 1, true); - PlanUtil::GenRegisterHint(&plan_); - sub_compile_tc->Count("[GraphCompile]" + name_ + " GenRegisterHint", 1, true); - // TODO(chengcheng): test collective boxing for multi-job. - PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_); - sub_compile_tc->Count("[GraphCompile]" + name_ + " GenCollectiveBoxingPlan", 1, true); - PlanUtil::DumpCtrlRegstInfoToPlan(&plan_); - sub_compile_tc->Count("[GraphCompile]" + name_ + " DumpCtrlRegstInfoToPlan", 1, true); - PlanUtil::PlanMemoryLog(&plan_, name_); - if (Singleton::Get()->enable_debug_mode()) { - PlanUtil::GenLightPlan(&plan_, name_); - } - sub_compile_tc->Count("[GraphCompile]" + name_ + " GenMemAndLightPlanLog", 1, true); - } - compile_tc->Count("[GraphCompile]" + name_ + " CompilePlan", 0); - if (GlobalProcessCtx::WorldSize() > 1) { - std::string plan_name = "plan:" + job_name(); - if (GlobalProcessCtx::IsThisProcessMaster()) { - // TODO(chengcheng): split plan for each rank. - Singleton::Get()->PushKV(plan_name, plan_); - } else { - Singleton::Get()->PullKV(plan_name, &plan_); - } - OF_SESSION_BARRIER(); - // NOTE(zwx): After barrier plan is synchronized between all ranks, - // then it can be cleared for saving mem. - if (GlobalProcessCtx::IsThisProcessMaster()) { - Singleton::Get()->ClearKV(plan_name); - } - } - compile_tc->Count("[GraphCompile]" + name_ + " SyncPlan", 0, true); - // NOTE(chengcheng): recovery op_attr - PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); - return Maybe::Ok(); -} - -Maybe NNGraph::InitRuntime() { - CHECK_OR_RETURN(!runtime_inited_) - << Error::RuntimeError() << "nn.Graph runtime is already initialized"; - - auto compile_tc = std::make_unique>(true, true); - NewRuntimeBuffers(); - - JUST(GetVariableRealBlobAfterSyncPlan()); - - // NOTE(strint): Do memory shrink to free cached memory in eager VM before graph runtime init. - JUST(vm::CurrentRankSync()); - auto* vm = JUST(SingletonMaybe()); - JUST(vm->ShrinkAllMem()); - - runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_object_)); - compile_tc->Count("[GraphCompile]" + name_ + " InitRuntime", 0, true); - JUST(LogProgress("[GraphCompile]" + name_ + " Done", true)); - - runtime_inited_ = true; - return Maybe::Ok(); -} - Maybe NNGraph::CompileAndInitRuntime() { JUST(AlignStatesAfterLogicalGraphCompile()); JUST(CompleteLogicalGraphForRuntime()); diff --git a/oneflow/core/framework/nn_graph.h b/oneflow/core/framework/nn_graph.h index bf4d3bf2706..02a357451c5 100644 --- a/oneflow/core/framework/nn_graph.h +++ b/oneflow/core/framework/nn_graph.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_ #define ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_H_ -#include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/framework/op_expr.h" @@ -107,6 +106,10 @@ class NNGraph final : public NNGraphIf { std::vector> cached_op_exprs; private: + // Compile the full task graph for all ranks and then broadcast to all ranks. + Maybe NaiveCompile(); + // Each rank compile it's task graph. + Maybe MasterAndWorkerRanksCompile(); Maybe RegisterFreeEagerTensorsToVariableOpNames(); Maybe RegisterNewVariableOpInJobPass(); Maybe DeleteOutdatedVariableInVariableTensorMgr(); diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp index 1e0021c0b7c..0562a60f34a 100644 --- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp +++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp @@ -73,6 +73,8 @@ void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& } inline std::string NewUniqueIdGbc() { + // The boxing task graph is built on rank 0 and broadcasted to all the ranks, + // so the ids here are unique among all the ranks. static std::atomic counter(0); static std::atomic curr_job_id(0); if (curr_job_id != GlobalJobDesc().job_id()) { diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index 7d6cb03bc47..6ebd4ac4cee 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/fake_consumed_regst_provider.h" #include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/job/compile_mode.h" namespace oneflow { @@ -52,7 +53,9 @@ class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider { std::shared_ptr op() const { return op_node_->shared_op(); } ExecNode::InferBlobDescsMethod GetInferBlobDescsMethod() const override { - return &ExecNode::InferBlobDescsByInputs; + // For default compilation mode, compute task node use input blob desc to infer output blob + // desc; For separate compilation mode, compute task node use NdSBP to infer output blob desc. + return InferBlobDescsMethodGetter::Visit(CHECK_JUST(CurrentCompileMode())); } protected: @@ -64,6 +67,14 @@ class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider { void InferProducedDataRegstTimeShape() override; private: + struct InferBlobDescsMethodGetter final : public CompileModeVisitor { + static ExecNode::InferBlobDescsMethod VisitNaive() { return &ExecNode::InferBlobDescsByInputs; } + static ExecNode::InferBlobDescsMethod VisitRankPerProcess() { + return &ExecNode::InferBlobDescsByNdSbp; + } + static ExecNode::InferBlobDescsMethod VisitInValid() { return nullptr; } + }; + ParallelContext parallel_ctx_; const OpNode* op_node_; HashSet fake_consumed_regst_names_; diff --git a/oneflow/core/graph/copy_task_node.cpp b/oneflow/core/graph/copy_task_node.cpp index 490fb6643f0..64b32e15cd7 100644 --- a/oneflow/core/graph/copy_task_node.cpp +++ b/oneflow/core/graph/copy_task_node.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/graph/task_stream_id.h" +#include "oneflow/core/graph/boxing_task_graph.pb.h" #include "oneflow/core/framework/user_op_registry_manager.h" namespace oneflow { @@ -95,7 +96,8 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() { } else { LOG(FATAL) << "unknow copy type: " << copy_type_; } - conf.set_name(std::string(copy_type_name) + "_" + NewUniqueId()); + conf.set_name(std::string(copy_type_name) + "_" + lbi().op_name() + "-" + lbi().blob_name() + "_" + + std::to_string(task_id())); *conf.mutable_user_conf()->mutable_op_type_name() = copy_type_name; auto in_regst = GetSoleConsumedRegst("copy_in"); CHECK_EQ(in_regst->NumOfLbi(), 1); diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index bf86e7913c0..3211466e076 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -140,7 +140,7 @@ void ExecNode::InferBlobDescsByInputs(const ParallelContext* parallel_ctx) { nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)); } CHECK_JUST_MSG(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, &GlobalJobDesc()), - std::stringstream() << " infer blob descs if failed, op name " << op_->op_loc()); + std::stringstream() << " infer blob descs is failed, op name " << op_->op_loc()); if (op_node != nullptr && parallel_ctx->parallel_num() > 1 && nd_sbp_signature != nullptr) { CHECK_JUST(CheckPhysicalBlobDesc( *op(), op()->output_bns(), diff --git a/oneflow/core/graph/plan_task_graph.cpp b/oneflow/core/graph/plan_task_graph.cpp index b593774290f..48b95c84989 100644 --- a/oneflow/core/graph/plan_task_graph.cpp +++ b/oneflow/core/graph/plan_task_graph.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/container_util.h" #include "oneflow/core/graph/plan_task_graph.h" namespace oneflow { @@ -35,15 +36,18 @@ void PlanTaskGraph::InitEdges() { PlanTaskNode* producer_node = task_id_and_plan_task_node.second; for (const auto& pair : producer_node->task_proto()->produced_regst_desc()) { for (int64_t consumer_task_id : pair.second.consumer_task_id()) { - PlanTaskNode* consumer_node = task_id2plan_task_node_.at(consumer_task_id); - Connect(producer_node, NewEdge(), consumer_node); + PlanTaskNode* consumer_node = CHECK_JUST(MapAt(task_id2plan_task_node_, consumer_task_id)); + TryConnect(producer_node, consumer_node); } } } } -const TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const { - return task_id2plan_task_node_.at(task_id)->task_proto(); +void PlanTaskGraph::TryConnect(PlanTaskNode* src, PlanTaskNode* dst) { + if (edges_.insert({src, dst}).second) { Connect(src, NewEdge(), dst); } } +const TaskProto* PlanTaskGraph::TaskProto4TaskId(int64_t task_id) const { + return CHECK_JUST(MapAt(task_id2plan_task_node_, task_id))->task_proto(); +} } // namespace oneflow diff --git a/oneflow/core/graph/plan_task_graph.h b/oneflow/core/graph/plan_task_graph.h index 67753560935..2fa3c9f3b3f 100644 --- a/oneflow/core/graph/plan_task_graph.h +++ b/oneflow/core/graph/plan_task_graph.h @@ -45,23 +45,24 @@ class PlanTaskNode final : public Node { const TaskProto* task_proto_; }; -class PlanTaskGraph final : public Graph { +class PlanTaskGraph : public Graph { public: OF_DISALLOW_COPY_AND_MOVE(PlanTaskGraph); explicit PlanTaskGraph(const Plan& plan); - ~PlanTaskGraph() = default; + virtual ~PlanTaskGraph() = default; const TaskProto* TaskProto4TaskId(int64_t task_id) const; const Plan& plan() const { return *plan_; } - private: + protected: void InitNodes(); void InitEdges(); + void TryConnect(PlanTaskNode* src, PlanTaskNode* dst); const Plan* plan_; HashMap task_id2plan_task_node_; + HashSet> edges_; }; - } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_PLAN_TASK_GRAPH_H_ diff --git a/oneflow/core/graph/slice_boxing_task_node.cpp b/oneflow/core/graph/slice_boxing_task_node.cpp index 038e1d70868..01be443718f 100644 --- a/oneflow/core/graph/slice_boxing_task_node.cpp +++ b/oneflow/core/graph/slice_boxing_task_node.cpp @@ -13,8 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/slice_boxing_task_node.h" + +#include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" namespace oneflow { diff --git a/oneflow/core/graph/task_id_generator.cpp b/oneflow/core/graph/task_id_generator.cpp index 8747ac931f0..4ed671828b8 100644 --- a/oneflow/core/graph/task_id_generator.cpp +++ b/oneflow/core/graph/task_id_generator.cpp @@ -50,6 +50,7 @@ void TaskIdGenerator::TryUpdateTaskIndex(const HashMap& task_ } TaskId TaskIdGenerator::Generate(const StreamId& stream_id) { + std::unique_lock lock(mutex_); if (stream_id2task_index_counter_.count(stream_id) == 0) { uint32_t init_task_index = 0; const int64_t i64_stream_id = EncodeStreamIdToInt64(stream_id); diff --git a/oneflow/core/graph/task_id_generator.h b/oneflow/core/graph/task_id_generator.h index abee6a57ddc..33e200a0335 100644 --- a/oneflow/core/graph/task_id_generator.h +++ b/oneflow/core/graph/task_id_generator.h @@ -35,6 +35,7 @@ class TaskIdGenerator final { void TryUpdateTaskIndex(const HashMap& task_index_state); private: + std::mutex mutex_; HashMap stream_id2task_index_counter_; // The task_index_init_state is used to initialize the `stream_id2task_index_counter_` hashmap. HashMap task_index_init_state_{}; diff --git a/oneflow/core/graph/task_stream_index_manager.h b/oneflow/core/graph/task_stream_index_manager.h index 34f892c2816..e569dabad86 100644 --- a/oneflow/core/graph/task_stream_index_manager.h +++ b/oneflow/core/graph/task_stream_index_manager.h @@ -21,13 +21,13 @@ limitations under the License. namespace oneflow { -class TaskStreamIndexManager final { +class TaskStreamIndexManager { public: using stream_index_t = StreamId::stream_index_t; OF_DISALLOW_COPY_AND_MOVE(TaskStreamIndexManager); TaskStreamIndexManager() = default; - ~TaskStreamIndexManager() = default; + virtual ~TaskStreamIndexManager() = default; StreamIndexGenerator* GetGenerator(const DeviceId& device_id); stream_index_t GetTaskStreamIndex(TaskType task_type, const DeviceId& device_id); diff --git a/oneflow/core/job/compile_mode.cpp b/oneflow/core/job/compile_mode.cpp new file mode 100644 index 00000000000..21ce131e3cb --- /dev/null +++ b/oneflow/core/job/compile_mode.cpp @@ -0,0 +1,64 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/job/compile_mode.h" +#include "oneflow/core/common/env_var/env_var.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/container_util.h" + +namespace oneflow { + +namespace { + +struct CompileModeName final : public CompileModeVisitor { + static std::string VisitNaive() { return "naive"; } + static std::string VisitRankPerProcess() { return "rank_per_process"; } + static std::string VisitInValid() { return "invalid"; } +}; + +std::unordered_map Name2CompileMode() { + std::unordered_map name2compile_mode; + for (int i = static_cast(CompileMode::kInvalid) + 1; + i != static_cast(CompileMode::kEnd); ++i) { + CompileMode compile_mode = static_cast(i); + CHECK(name2compile_mode.emplace(CompileModeName::Visit(compile_mode), compile_mode).second); + } + return name2compile_mode; +} + +std::string GetValidCompileModeNames() { + std::stringstream ss; + for (int i = static_cast(CompileMode::kInvalid) + 1; + i != static_cast(CompileMode::kEnd); ++i) { + if (i > static_cast(CompileMode::kInvalid) + 1) { ss << ", "; } + CompileMode compile_mode = static_cast(i); + ss << CompileModeName::Visit(compile_mode); + } + return ss.str(); +} + +} // namespace + +Maybe CurrentCompileMode() { + static thread_local CompileMode mode = + JUST_MSG(MapAt(Name2CompileMode(), ThreadLocalEnvString()), + std::stringstream() + << "ONEFLOW_LAZY_COMPILER(value: " + << ThreadLocalEnvString() + << ") is invalid. valid options: \"" << GetValidCompileModeNames() << "\""); + return mode; +} + +} // namespace oneflow diff --git a/oneflow/core/job/compile_mode.h b/oneflow/core/job/compile_mode.h new file mode 100644 index 00000000000..2c88c5de346 --- /dev/null +++ b/oneflow/core/job/compile_mode.h @@ -0,0 +1,50 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_JOB_COMPILE_MODE_H_ +#define ONEFLOW_CORE_JOB_COMPILE_MODE_H_ + +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +enum class CompileMode { + kInvalid = 0, // make sure kInvalid is the first CompileMode + kNaive, + kRankPerProcess, + kEnd, // make sure kEnd is the last CompileMode +}; + +template +struct CompileModeVisitor { + template + static auto Visit(CompileMode compile_mode, Args&&... args) { + switch (compile_mode) { + case CompileMode::kNaive: return DerivedT::VisitNaive(std::forward(args)...); + case CompileMode::kRankPerProcess: + return DerivedT::VisitRankPerProcess(std::forward(args)...); + default: { + LOG(FATAL) << "invalid compile mode"; + return DerivedT::VisitInValid(std::forward(args)...); + } + } + } +}; + +Maybe CurrentCompileMode(); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_COMPILE_MODE_H_ diff --git a/oneflow/core/job/id_manager.cpp b/oneflow/core/job/id_manager.cpp index 3e8651ab037..7dcf4843259 100644 --- a/oneflow/core/job/id_manager.cpp +++ b/oneflow/core/job/id_manager.cpp @@ -14,17 +14,38 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job/id_manager.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/job/id_state.h" namespace oneflow { +namespace { + +constexpr static int64_t kRankLimitShift = 16; +constexpr static int64_t kIdLimitShift = (sizeof(int64_t) * 8 - kRankLimitShift); +static_assert(kIdLimitShift > 0, ""); + +int64_t AddCurrentRankOffset(int64_t x) { + CHECK_GE(x, 0); + CHECK_LT(x, (static_cast(1) << kIdLimitShift)); + return (static_cast(GlobalProcessCtx::Rank()) << kIdLimitShift) + x; +} + +} // namespace + IDMgr::IDMgr() { regst_desc_id_count_ = 0; mem_block_id_count_ = 0; chunk_id_count_ = 0; + CHECK_LE(GlobalProcessCtx::WorldSize(), (static_cast(1) << kRankLimitShift)); } +int64_t IDMgr::NewRegstDescId() { return AddCurrentRankOffset(regst_desc_id_count_++); } + +int64_t IDMgr::NewMemBlockId() { return AddCurrentRankOffset(mem_block_id_count_++); } + +int64_t IDMgr::NewChunkId() { return AddCurrentRankOffset(chunk_id_count_++); } void IDMgr::SaveIdAndTaskIndex(IdState* id_state) { id_state->regst_desc_id_state_ = regst_desc_id_count_; id_state->mem_block_id_state_ = mem_block_id_count_; @@ -33,9 +54,12 @@ void IDMgr::SaveIdAndTaskIndex(IdState* id_state) { } void IDMgr::TryUpdateIdAndTaskIndex(const IdState* id_state) { - regst_desc_id_count_ = std::max(regst_desc_id_count_, id_state->regst_desc_id_state_); - mem_block_id_count_ = std::max(mem_block_id_count_, id_state->mem_block_id_state_); - chunk_id_count_ = std::max(chunk_id_count_, id_state->chunk_id_state_); + regst_desc_id_count_ = std::max(regst_desc_id_count_.load(std::memory_order_relaxed), + id_state->regst_desc_id_state_); + mem_block_id_count_ = + std::max(mem_block_id_count_.load(std::memory_order_relaxed), id_state->mem_block_id_state_); + chunk_id_count_ = + std::max(chunk_id_count_.load(std::memory_order_relaxed), id_state->chunk_id_state_); task_id_gen_.TryUpdateTaskIndex(id_state->task_index_state_); } diff --git a/oneflow/core/job/id_manager.h b/oneflow/core/job/id_manager.h index 38f0ac809bf..4a1b442f96b 100644 --- a/oneflow/core/job/id_manager.h +++ b/oneflow/core/job/id_manager.h @@ -30,9 +30,9 @@ class IDMgr final { OF_DISALLOW_COPY_AND_MOVE(IDMgr); ~IDMgr() = default; - int64_t NewRegstDescId() { return regst_desc_id_count_++; } - int64_t NewMemBlockId() { return mem_block_id_count_++; } - int64_t NewChunkId() { return chunk_id_count_++; } + int64_t NewRegstDescId(); + int64_t NewMemBlockId(); + int64_t NewChunkId(); TaskIdGenerator* GetTaskIdGenerator() { return &task_id_gen_; } @@ -43,9 +43,9 @@ class IDMgr final { friend class Singleton; IDMgr(); - int64_t regst_desc_id_count_; - int64_t mem_block_id_count_; - int64_t chunk_id_count_; + std::atomic regst_desc_id_count_; + std::atomic mem_block_id_count_; + std::atomic chunk_id_count_; TaskIdGenerator task_id_gen_; }; diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp index 3ea1959eb0e..9c8d3f4bfaa 100644 --- a/oneflow/core/job/intra_job_mem_sharing_util.cpp +++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp @@ -171,7 +171,6 @@ void GenMemChainTasksAndRegsts( for (auto& device_chain_pair : device2chain2mem_chain) { if (device_chain_pair.second.empty()) { continue; } - // sort std::vector mem_chains; mem_chains.reserve(device_chain_pair.second.size()); for (auto& pair : device_chain_pair.second) { mem_chains.emplace_back(&pair.second); } diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index ce9bb4b14ed..54a2db5902c 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -212,11 +212,11 @@ message JobSignatureDef { } enum StraightenAlgorithmTag { - kDisableStraighten = 1; - kOverlap4Transfer = 2; + kDisableStraighten = 1; + kOverlap4Transfer = 2; kCompressMemory = 3; kOverlap4CpuGpu = 4; - kDelayShortGpu = 5; + kDelayShortGpu = 5; } enum AutoMemoryStrategy { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index c4ac9d6e311..bc16c618906 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/plan_util.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/graph/plan_task_graph.h" #include "oneflow/core/graph/boxing/collective_boxing_util.h" @@ -808,11 +809,15 @@ std::function PlanUtil::MakeMutRegstDesc4Id(Plan* plan TaskProto* task = plan->mutable_task(i); for (auto& pair : *task->mutable_produced_regst_desc()) { int64_t regst_desc_id = pair.second.regst_desc_id(); - regst_desc_id2regst_desc->insert({regst_desc_id, &pair.second}); + CHECK(regst_desc_id2regst_desc->insert({regst_desc_id, &pair.second}).second) + << "regst_desc_id2regst_desc has got duplicated regst_desc_id " << regst_desc_id; } } return [regst_desc_id2regst_desc](int64_t regst_desc_id) -> RegstDescProto* { - return regst_desc_id2regst_desc->at(regst_desc_id); + auto iter = regst_desc_id2regst_desc->find(regst_desc_id); + CHECK(iter != regst_desc_id2regst_desc->end()) + << "regst_desc_id " << regst_desc_id << " can't be found in plan."; + return iter->second; }; } @@ -876,7 +881,7 @@ bool IsCollectiveBoxingTaskType(TaskType task_type) { bool IsCollectiveBoxingNode(const PlanTaskNode* node) { const TaskType task_type = node->task_proto()->task_type(); - return task_type == TaskType::kCollectiveBoxingGeneric; + return IsCollectiveBoxingTaskType(task_type); } const boxing::collective::RankDesc& GetRankDesc(const OperatorConf& conf) { @@ -1224,7 +1229,7 @@ void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) { } } -void PlanUtil::GenLightPlan(Plan* plan, const std::string& plan_name) { +void PlanUtil::GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank) { // NOTE(chengcheng): ordered_tasks is NOT exec order, just task id order. std::vector ordered_tasks; for (const TaskProto& task : plan->task()) { ordered_tasks.push_back(&task); } @@ -1287,6 +1292,8 @@ void PlanUtil::GenLightPlan(Plan* plan, const std::string& plan_name) { rank2ordered_task.at(task->machine_id()).push_back(task); } for (int64_t rank = 0; rank < GlobalProcessCtx::WorldSize(); ++rank) { + // Filter rank to generate log. + if (limited_rank >= 0 && rank != limited_rank) { continue; } auto file_stream = TeePersistentLogStream::Create(plan_name + "_rank_" + std::to_string(rank) + "_light_plan"); file_stream << "rank : " << std::to_string(rank) << "\n"; diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h index c8601ae6682..c6e8ae32749 100644 --- a/oneflow/core/job/plan_util.h +++ b/oneflow/core/job/plan_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/job.pb.h" #include "oneflow/core/graph/stream_id.h" +#include "oneflow/core/graph/plan_task_graph.h" namespace oneflow { @@ -45,7 +46,10 @@ struct PlanUtil { static void DumpCtrlRegstInfoToPlan(Plan* plan); static void GenCollectiveBoxingPlan(Job* job, Plan* plan); static void GenRegisterHint(Plan* plan); - static void GenLightPlan(Plan* plan, const std::string& plan_name); + // Generate readable plan log from plan proto. + // Use filter_rank to choose which rank to generate. When filter_rank is -1, all rank will be + // generated. The default value of filter_rank is -1. + static void GenLightPlan(Plan* plan, const std::string& plan_name, int64_t limited_rank = -1); static void PlanMemoryLog(Plan* plan, const std::string& plan_name); static const oneflow::OpAttribute& GetOpAttribute(const Plan* plan, int64_t job_id, const oneflow::KernelConf& kernel_conf); diff --git a/oneflow/core/memory/chunk_manager.cpp b/oneflow/core/memory/chunk_manager.cpp index 2943e2760ad..d7173440602 100644 --- a/oneflow/core/memory/chunk_manager.cpp +++ b/oneflow/core/memory/chunk_manager.cpp @@ -22,6 +22,7 @@ namespace oneflow { void ChunkMgr::GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid, std::vector* chunks) const { + std::unique_lock guard(mutex_); chunks->clear(); auto chunk_ids_it = mzuid2chunk_ids_.find(mem_zone_uid); if (chunk_ids_it != mzuid2chunk_ids_.end()) { @@ -36,6 +37,7 @@ void ChunkMgr::GetChunkProtosByMemZoneUniqueId(int64_t mem_zone_uid, } void ChunkMgr::AddChunkProto(const ChunkProto& chunk) { + std::unique_lock guard(mutex_); const int64_t mem_zone_uid = memory::GetUniqueMemCaseId(chunk.machine_id(), chunk.mem_case()); CHECK( chunk_id2chunk_proto_.emplace(chunk.chunk_id(), std::make_unique(chunk)).second); @@ -47,6 +49,7 @@ void ChunkMgr::AddChunkProto(const ChunkProto& chunk) { } char* ChunkMgr::FindOrCreateChunk(const ChunkProto& chunk) { + std::unique_lock guard(mutex_); CHECK_EQ(GlobalProcessCtx::Rank(), chunk.machine_id()); auto it = chunk_id2chunk_.find(chunk.chunk_id()); if (it == chunk_id2chunk_.end()) { diff --git a/oneflow/core/memory/chunk_manager.h b/oneflow/core/memory/chunk_manager.h index 1507eef67b5..44970ba70e6 100644 --- a/oneflow/core/memory/chunk_manager.h +++ b/oneflow/core/memory/chunk_manager.h @@ -51,7 +51,7 @@ class ChunkMgr final { // for runtime HashMap chunk_id2chunk_; - std::mutex mutex_; + mutable std::mutex mutex_; }; } // namespace oneflow diff --git a/oneflow/core/operator/collective_boxing_pack_op.cpp b/oneflow/core/operator/collective_boxing_pack_op.cpp index f490f1086ee..0a3ddf62fa5 100644 --- a/oneflow/core/operator/collective_boxing_pack_op.cpp +++ b/oneflow/core/operator/collective_boxing_pack_op.cpp @@ -60,7 +60,7 @@ Maybe CollectiveBoxingPackOp::InferOutBlobDescs( const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - *out_blob_desc = *in_blob_desc; + *CHECK_NOTNULL(out_blob_desc) = *CHECK_NOTNULL(in_blob_desc); // NOLINT out_blob_desc->set_shape(Shape({in_blob_desc->shape().elem_cnt()})); return Maybe::Ok(); } diff --git a/oneflow/core/operator/op_attribute.proto b/oneflow/core/operator/op_attribute.proto index 548bec15790..296bd401be5 100644 --- a/oneflow/core/operator/op_attribute.proto +++ b/oneflow/core/operator/op_attribute.proto @@ -28,7 +28,6 @@ message OpAttribute { optional SbpSignature sbp_signature = 104; optional LocalSignature local_signature = 105; optional BlobDescSignature logical_blob_desc_signature = 106; - optional ParallelSignature parallel_signature = 108; optional ParallelConfSignature parallel_conf_signature = 109; optional NdSbpSignature nd_sbp_signature = 110; } diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 6e5a3ecd4fd..9737496c26f 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -1412,35 +1412,6 @@ Maybe Operator::ToOpAttribute(OpAttribute* op_attribute) const { if (!has_same_parallel_conf_as_op) { (*map)[pair.first] = pair.second->parallel_conf(); } } } - if (op_parallel_desc_ && bn2parallel_desc_) { - if (op_conf().scope_symbol_id() != 0) { - const auto& scope_storage = *Singleton>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - int64_t parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); - auto* parallel_signature = op_attribute->mutable_parallel_signature(); - parallel_signature->set_op_parallel_desc_symbol_id(parallel_desc_symbol_id); - auto* symbol_map = parallel_signature->mutable_bn_in_op2parallel_desc_symbol_id(); - for (const auto& pair : *bn2parallel_desc_) { - if (*pair.second == *op_parallel_desc_) { - (*symbol_map)[pair.first] = parallel_desc_symbol_id; - } else { - ParallelConf parallel_conf = pair.second->parallel_conf(); - const auto MakeParallelDescSymbol = [¶llel_conf]() -> Maybe { - int64_t symbol_id = 0; - const auto BuildInstruction = - [&symbol_id, ¶llel_conf](InstructionsBuilder* builder) -> Maybe { - symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id()); - return Maybe::Ok(); - }; - JUST(PhysicalRun(BuildInstruction)); - return symbol_id; - }; - (*symbol_map)[pair.first] = JUST(MakeParallelDescSymbol()); - } - } - for (const auto& tbn : tmp_bns()) { (*symbol_map)[tbn] = parallel_desc_symbol_id; } - } - } return Maybe::Ok(); } diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index 1da2e267325..38203e11be5 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -64,7 +64,7 @@ class RegstDesc final { // mem const MemoryCase& mem_case() const { return mem_case_; } MemoryCase* mut_mem_case() { return &mem_case_; } - bool enable_reuse_mem() { return enable_reuse_mem_; } + bool enable_reuse_mem() const { return enable_reuse_mem_; } void set_enable_reuse_mem(bool enable_reuse_mem) { enable_reuse_mem_ = enable_reuse_mem; } int64_t mem_block_offset() const; void set_mem_block_offset(int64_t val) { mem_block_offset_ = val; } diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto index f6af6a9d9d3..2adc22e4bf9 100644 --- a/oneflow/core/register/register_desc.proto +++ b/oneflow/core/register/register_desc.proto @@ -13,7 +13,7 @@ message LbiBlobDescPair { message DataRegstDesc { repeated LbiBlobDescPair lbi2blob_desc = 1; - required ShapeProto time_shape = 3; + optional ShapeProto time_shape = 3; } message CtrlRegstDesc { diff --git a/oneflow/core/rpc/include/global_process_ctx.h b/oneflow/core/rpc/include/global_process_ctx.h index 451ce91070f..00cfc159581 100644 --- a/oneflow/core/rpc/include/global_process_ctx.h +++ b/oneflow/core/rpc/include/global_process_ctx.h @@ -21,6 +21,7 @@ limitations under the License. namespace oneflow { struct GlobalProcessCtx { + static void GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id, int64_t* device_id); static void GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id); static int64_t Rank(); static int64_t LocalRank(); diff --git a/oneflow/core/rpc/lib/global_process_ctx.cpp b/oneflow/core/rpc/lib/global_process_ctx.cpp index ab502c3b1b0..1b4c8d25949 100644 --- a/oneflow/core/rpc/lib/global_process_ctx.cpp +++ b/oneflow/core/rpc/lib/global_process_ctx.cpp @@ -20,6 +20,12 @@ limitations under the License. namespace oneflow { +void GlobalProcessCtx::GetMachineIdAndDeviceId(int64_t rank, int64_t* machine_id, + int64_t* device_id) { + *machine_id = rank; + *device_id = rank % NumOfProcessPerNode(); +} + void GlobalProcessCtx::GetCurrentMachineIdAndDeviceId(int64_t* machine_id, int64_t* device_id) { *machine_id = Rank(); *device_id = LocalRank(); diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py index bf877279ee5..2b6aca3627c 100644 --- a/python/oneflow/nn/graph/graph.py +++ b/python/oneflow/nn/graph/graph.py @@ -153,6 +153,7 @@ def __init__( self._unique_global_op_dict = dict() self._unique_identity_op_dict = dict() + # Graph compilation related. # forward graph job proto self._forward_job_proto = None # forward, backward and optimized graph job proto @@ -163,6 +164,14 @@ def __init__( self._args_repr = [] self._outs_repr = [] self._oneflow_internal_graph_ir__ = None + enalbe_lazy_separate_compile = os.environ.get( + "ONEFLOW_ENABLE_LAZY_SEPARATE_COMPILE" + ) + if enalbe_lazy_separate_compile != None and enalbe_lazy_separate_compile == "1": + os.environ["ONEFLOW_LAZY_COMPILE_MODE"] = "rank_per_process" + # Separate compile mode only works with nccl use compute stream and logical chain. + os.environ["ENABLE_LOGICAL_CHAIN"] = "1" + oneflow.boxing.nccl.enable_use_compute_stream(True) self._session = session_ctx.GetDefaultSession() assert type(self._session) is MultiClientSession diff --git a/python/oneflow/test/graph/test_alexnet_auto_parallel.py b/python/oneflow/test/graph/test_alexnet_auto_parallel.py index daa82a04458..1d9d16a4779 100644 --- a/python/oneflow/test/graph/test_alexnet_auto_parallel.py +++ b/python/oneflow/test/graph/test_alexnet_auto_parallel.py @@ -179,6 +179,8 @@ def build(self, image): epoch, i, len(train_iter), l, end_t - start_t ) ) + # Stop after 20 iters to save time + break if flow.env.get_rank() == 0: print("epoch %d train done, start validation" % epoch) diff --git a/python/oneflow/test/graph/test_fx_fuse.py b/python/oneflow/test/graph/test_fx_fuse.py index 3c570c883ad..2e93b0bc3ff 100644 --- a/python/oneflow/test/graph/test_fx_fuse.py +++ b/python/oneflow/test/graph/test_fx_fuse.py @@ -23,107 +23,113 @@ from typing import Dict, Any, Tuple -class TestConvBnFuse(flow.unittest.TestCase): - def fuse_conv_bn_eval(conv, bn): - """ - Given a conv Module `A` and an batch_norm module `B`, returns a conv - module `C` such that C(x) == B(A(x)) in inference mode. - """ - assert not (conv.training or bn.training), "Fusion only for eval!" - fused_conv = copy.deepcopy(conv) - - fused_conv.weight, fused_conv.bias = TestConvBnFuse.fuse_conv_bn_weights( - fused_conv.weight, - fused_conv.bias, - bn.running_mean, - bn.running_var, - bn.eps, - bn.weight, - bn.bias, - ) - - return fused_conv - - def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): - if conv_b is None: - conv_b = flow.zeros_like(bn_rm) - if bn_w is None: - bn_w = flow.ones_like(bn_rm) - if bn_b is None: - bn_b = flow.zeros_like(bn_rm) - bn_var_rsqrt = flow.rsqrt(bn_rv + bn_eps) - - conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape( - [-1] + [1] * (len(conv_w.shape) - 1) - ) - conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b - - return flow.nn.Parameter(conv_w), flow.nn.Parameter(conv_b) - - def _parent_name(target: str) -> Tuple[str, str]: - """ - Splits a qualname into parent path and last atom. - For example, `foo.bar.baz` -> (`foo.bar`, `baz`) - """ - *parent, name = target.rsplit(".", 1) - return parent[0] if parent else "", name - - def replace_node_module( - node: flow.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module - ): - assert isinstance(node.target, str) - parent_name, name = TestConvBnFuse._parent_name(node.target) - setattr(modules[parent_name], name, new_module) - - def fuse(model: flow.nn.Module) -> flow.nn.Module: - model = copy.deepcopy(model) - # The first step of most FX passes is to symbolically trace our model to - # obtain a `GraphModule`. This is a representation of our original model - # that is functionally identical to our original model, except that we now - # also have a graph representation of our forward pass. - fx_model: flow.fx.GraphModule = flow.fx.symbolic_trace(model) - modules = dict(fx_model.named_modules()) - - # The primary representation for working with FX are the `Graph` and the - # `Node`. Each `GraphModule` has a `Graph` associated with it - this - # `Graph` is also what generates `GraphModule.code`. - # The `Graph` itself is represented as a list of `Node` objects. Thus, to - # iterate through all of the operations in our graph, we iterate over each - # `Node` in our `Graph`. - for node in fx_model.graph.nodes: - # The FX IR contains several types of nodes, which generally represent - # call sites to modules, functions, or methods. The type of node is - # determined by `Node.op`. - if ( - node.op != "call_module" - ): # If our current node isn't calling a Module then we can ignore it. +def _fuse_conv_bn_eval(conv, bn): + """ + Given a conv Module `A` and an batch_norm module `B`, returns a conv + module `C` such that C(x) == B(A(x)) in inference mode. + """ + assert not (conv.training or bn.training), "Fusion only for eval!" + fused_conv = copy.deepcopy(conv) + + fused_conv.weight, fused_conv.bias = _fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn.running_mean, + bn.running_var, + bn.eps, + bn.weight, + bn.bias, + ) + + return fused_conv + + +def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): + if conv_b is None: + conv_b = flow.zeros_like(bn_rm) + if bn_w is None: + bn_w = flow.ones_like(bn_rm) + if bn_b is None: + bn_b = flow.zeros_like(bn_rm) + bn_var_rsqrt = flow.rsqrt(bn_rv + bn_eps) + + conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape( + [-1] + [1] * (len(conv_w.shape) - 1) + ) + conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + + return flow.nn.Parameter(conv_w), flow.nn.Parameter(conv_b) + + +def _parent_name(target: str) -> Tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + + +def _replace_node_module( + node: flow.fx.Node, modules: Dict[str, Any], new_module: flow.nn.Module +): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, new_module) + + +def _fx_fuse(model: flow.nn.Module) -> flow.nn.Module: + model = copy.deepcopy(model) + # The first step of most FX passes is to symbolically trace our model to + # obtain a `GraphModule`. This is a representation of our original model + # that is functionally identical to our original model, except that we now + # also have a graph representation of our forward pass. + fx_model: flow.fx.GraphModule = flow.fx.symbolic_trace(model) + modules = dict(fx_model.named_modules()) + + # The primary representation for working with FX are the `Graph` and the + # `Node`. Each `GraphModule` has a `Graph` associated with it - this + # `Graph` is also what generates `GraphModule.code`. + # The `Graph` itself is represented as a list of `Node` objects. Thus, to + # iterate through all of the operations in our graph, we iterate over each + # `Node` in our `Graph`. + for node in fx_model.graph.nodes: + # The FX IR contains several types of nodes, which generally represent + # call sites to modules, functions, or methods. The type of node is + # determined by `Node.op`. + if ( + node.op != "call_module" + ): # If our current node isn't calling a Module then we can ignore it. + continue + # For call sites, `Node.target` represents the module/function/method + # that's being called. Here, we check `Node.target` to see if it's a + # batch norm module, and then check `Node.args[0].target` to see if the + # input `Node` is a convolution. + if ( + type(modules[node.target]) is nn.BatchNorm2d + and type(modules[node.args[0].target]) is nn.Conv2d + ): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes continue - # For call sites, `Node.target` represents the module/function/method - # that's being called. Here, we check `Node.target` to see if it's a - # batch norm module, and then check `Node.args[0].target` to see if the - # input `Node` is a convolution. - if ( - type(modules[node.target]) is nn.BatchNorm2d - and type(modules[node.args[0].target]) is nn.Conv2d - ): - if len(node.args[0].users) > 1: # Output of conv is used by other nodes - continue - conv = modules[node.args[0].target] - bn = modules[node.target] - fused_conv = TestConvBnFuse.fuse_conv_bn_eval(conv, bn) - TestConvBnFuse.replace_node_module(node.args[0], modules, fused_conv) - # As we've folded the batch nor into the conv, we need to replace all uses - # of the batch norm with the conv. - node.replace_all_uses_with(node.args[0]) - # Now that all uses of the batch norm have been replaced, we can - # safely remove the batch norm. - fx_model.graph.erase_node(node) - fx_model.graph.lint() - # After we've modified our graph, we need to recompile our graph in order - # to keep the generated code in sync. - fx_model.recompile() - return fx_model - + conv = modules[node.args[0].target] + bn = modules[node.target] + fused_conv = _fuse_conv_bn_eval(conv, bn) + _replace_node_module(node.args[0], modules, fused_conv) + # As we've folded the batch nor into the conv, we need to replace all uses + # of the batch norm with the conv. + node.replace_all_uses_with(node.args[0]) + # Now that all uses of the batch norm have been replaced, we can + # safely remove the batch norm. + fx_model.graph.erase_node(node) + fx_model.graph.lint() + # After we've modified our graph, we need to recompile our graph in order + # to keep the generated code in sync. + fx_model.recompile() + return fx_model + + +@flow.unittest.skip_unless_1n1d() +class TestConvBnFuse(flow.unittest.TestCase): def test_fuse(test_case): class WrappedBatchNorm(nn.Module): def __init__(self): @@ -154,7 +160,7 @@ def forward(self, x): model.eval() - fused_model = TestConvBnFuse.fuse(model) + fused_model = _fx_fuse(model) for i in range(10): inp = flow.randn(5, 1, 32, 32) test_case.assertTrue( diff --git a/python/oneflow/test/graph/test_fx_replace_ops.py b/python/oneflow/test/graph/test_fx_replace_ops.py index a3d39dcc406..953229b375f 100644 --- a/python/oneflow/test/graph/test_fx_replace_ops.py +++ b/python/oneflow/test/graph/test_fx_replace_ops.py @@ -31,6 +31,7 @@ def forward(self, x, w1, w2): return x + flow.max(m1) + flow.max(m2) +@flow.unittest.skip_unless_1n1d() class TestReplaceOps(flow.unittest.TestCase): def test_pattern(test_case): traced = symbolic_trace(M()) diff --git a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py index 19a76c85438..33e127e6d26 100644 --- a/python/oneflow/test/graph/test_fx_symbolic_trace_module.py +++ b/python/oneflow/test/graph/test_fx_symbolic_trace_module.py @@ -57,6 +57,7 @@ def forward(self, x: flow.Tensor) -> flow.Tensor: return x +@flow.unittest.skip_unless_1n1d() class TestAlexNet(flow.unittest.TestCase): def test_alexnet(test_case): m = AlexNet() diff --git a/python/oneflow/test/graph/test_graph_pipeline.py b/python/oneflow/test/graph/test_graph_pipeline.py index b1c3128be7e..9fcff61e85e 100644 --- a/python/oneflow/test/graph/test_graph_pipeline.py +++ b/python/oneflow/test/graph/test_graph_pipeline.py @@ -145,6 +145,12 @@ def __init__(self): self.pp_m.stage_1_m.to(GraphModule).set_stage(1) self.pp_m.stage_2_m.to(GraphModule).set_stage(2) self.pp_m.stage_3_m.to(GraphModule).set_stage(3) + + self.pp_m.stage_0_m.to(GraphModule).activation_checkpointing = True + self.pp_m.stage_1_m.to(GraphModule).activation_checkpointing = True + self.pp_m.stage_2_m.to(GraphModule).activation_checkpointing = True + self.pp_m.stage_3_m.to(GraphModule).activation_checkpointing = True + self.mseloss = flow.nn.MSELoss("sum") self.add_optimizer(sgd) # Let graph to do gradient accumulatioin, pipline execution depends on gradient accumulatioin. diff --git a/python/oneflow/test/graph/test_graph_separate_compile.py b/python/oneflow/test/graph/test_graph_separate_compile.py new file mode 100644 index 00000000000..d4c76428d8a --- /dev/null +++ b/python/oneflow/test/graph/test_graph_separate_compile.py @@ -0,0 +1,211 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import contextlib +import os +import numpy as np + +import oneflow as flow +from oneflow import nn +import oneflow.unittest + + +@contextlib.contextmanager +def modified_environ(*remove, **update): + """ + From: https://stackoverflow.com/questions/2059482/temporarily-modify-the-current-processs-environment + Temporarily updates the ``os.environ`` dictionary in-place. + + The ``os.environ`` dictionary is updated in-place so that the modification + is sure to work in all situations. + + :param remove: Environment variables to remove. + :param update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + +def run_testcase_with_sep_compile(test_case_cls): + new_cls = type("SeparationCompile_" + test_case_cls.__name__, (test_case_cls,), {}) + with modified_environ( + ONEFLOW_LAZY_COMPILE_MODE="rank_per_process", ENABLE_LOGICAL_CHAIN="1" + ): + assert os.environ.get("ONEFLOW_LAZY_COMPILE_MODE") == "rank_per_process" + assert os.environ.get("ENABLE_LOGICAL_CHAIN") == "1" + flow.boxing.nccl.enable_use_compute_stream(True) + unittest.TextTestRunner().run( + unittest.TestLoader().loadTestsFromTestCase(new_cls) + ) + + +def _get_comb1to2d_test(): + class _TestModuleDiffHierarchy(nn.Module): + def forward(self, x): + sbp_1ds = [ + flow.sbp.broadcast, + flow.sbp.partial_sum, + flow.sbp.split(0), + flow.sbp.split(1), + flow.sbp.split(2), + ] + + for sbp1 in sbp_1ds: + + for sbp2 in sbp_1ds: + for sbp3 in sbp_1ds: + # (2, 2) -> 4 + x = x.to_global( + placement=flow.placement( + type="cuda", ranks=np.array(range(4)) + ), + sbp=[sbp1], + ) + # 4 -> (2, 2) + x = x.to_global( + placement=flow.placement( + type="cuda", ranks=np.array(range(4)).reshape(2, 2) + ), + sbp=[sbp2, sbp3], + ) + + return x + + class _TestModuleDiffPlacement(nn.Module): + def forward(self, x): + sbp_1ds = [ + flow.sbp.broadcast, + flow.sbp.partial_sum, + flow.sbp.split(0), + flow.sbp.split(1), + flow.sbp.split(2), + ] + for sbp1 in sbp_1ds: + for sbp2 in sbp_1ds: + for sbp3 in sbp_1ds: + # (2, 2) -> 3 + # 4 is not divisible by 3 + x = x.to_global( + placement=flow.placement( + type="cuda", ranks=np.array(range(3)) + ), + sbp=[sbp1], + ) + # 3 -> (2, 2) + x = x.to_global( + placement=flow.placement( + type="cuda", ranks=np.array(range(4)).reshape(2, 2) + ), + sbp=[sbp2, sbp3], + ) + + return x + + class _TestGraph(nn.Graph): + def __init__(self, model): + super().__init__() + self.model = model + + def build(self, x): + x = self.model(x) + return x + + @flow.unittest.skip_unless_1n4d() + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + class TestSepCompileLazyAllSbpCombinationTesting(flow.unittest.TestCase): + def test_lazy_boxing_2d_all_combination_diff_hierarchy(test_case): + os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" + os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" + + x = flow.ones( + 4, + 12, + 4, + sbp=[flow.sbp.broadcast, flow.sbp.broadcast], + placement=flow.placement( + type="cuda", ranks=np.array(range(4)).reshape(2, 2) + ), + ) + model_diff_hierarchy = _TestModuleDiffHierarchy() + graph_diff_hierarchy = _TestGraph(model_diff_hierarchy) + y = graph_diff_hierarchy(x) + + def test_lazy_boxing_2d_all_combination_diff_placement(test_case): + os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" + os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" + + x = flow.ones( + 4, + 12, + 4, + sbp=[flow.sbp.broadcast, flow.sbp.broadcast], + placement=flow.placement( + type="cuda", ranks=np.array(range(4)).reshape(2, 2) + ), + ) + model_diff_placement = _TestModuleDiffPlacement() + graph_diff_placement = _TestGraph(model_diff_placement) + z = graph_diff_placement(x) + test_case.assertTrue(np.allclose(x.numpy(), z.numpy(), 1e-05, 1e-05)) + + return TestSepCompileLazyAllSbpCombinationTesting + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n4d() +class TestSeparationCompile(oneflow.unittest.TestCase): + def test_test_alexnet_auto_parallel(test_case): + from test_alexnet_auto_parallel import TestAlexnetAutoParallel + + run_testcase_with_sep_compile(TestAlexnetAutoParallel) + + def _test_comb1to2d(test_case): + run_testcase_with_sep_compile(_get_comb1to2d_test()) + + def test_graph_zero(test_case): + from test_graph_zero import TestLinearTrainGraph2DWithZeRO + + run_testcase_with_sep_compile(TestLinearTrainGraph2DWithZeRO) + + def test_graph_clip_grad_norm(test_case): + from test_graph_clip_grad_norm import TestGraphClipGradNorm + + run_testcase_with_sep_compile(TestGraphClipGradNorm) + + def test_graph_pipeline_grad_acc_and_activatioin_checkpointing(test_case): + from test_graph_pipeline import TestGraphPipeline + + run_testcase_with_sep_compile(TestGraphPipeline) + + +if __name__ == "__main__": + unittest.main()